#define GGML_COMMON_DECL_METAL
#define GGML_COMMON_IMPL_METAL

#if defined(GGML_METAL_EMBED_LIBRARY)
__embed_ggml-common.h__
#else
#include "ggml-common.h"
#endif

#include "ggml-metal-impl.h"

#include <metal_stdlib>

#ifdef GGML_METAL_HAS_TENSOR
#include <metal_tensor>

#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
#endif

using namespace metal;

#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }

#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))

#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)

#define N_SIMDWIDTH 32 // assuming SIMD group size is 32

// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
//
// cmd:
//   .../usr/bin/metal -dM -E -c                             ggml/src/ggml-metal/ggml-metal.metal
//   .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/ggml-metal.metal
//

#if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16)
#undef GGML_METAL_HAS_BF16
#endif

#if defined(GGML_METAL_HAS_BF16)
typedef matrix<bfloat, 4, 4> bfloat4x4;
typedef matrix<bfloat, 2, 4> bfloat2x4;
#endif

constexpr constant static float kvalues_iq4nl_f[16] = {
    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};

constexpr constant static float kvalues_mxfp4_f[16] = {
    0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
};
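
// binary search for the index of the entry in the sorted table val[0..n) that is closest to x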
static inline int best_index_int8(int n, constant float * val, float x) {
    if (x <= val[0]) return 0;
    if (x >= val[n-1]) return n-1;
    int ml = 0, mu = n-1;
    while (mu-ml > 1) {
        int mav = (ml+mu)/2;
        if (x < val[mav]) mu = mav; else ml = mav;
    }
    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
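
// interpret an 8-bit E8M0 (exponent-only) scale as 2^(x - 127);
// x == 0 maps to the subnormal 2^-127 instead of zero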
static inline float e8m0_to_fp32(uint8_t x) {
    uint32_t bits;

    if (x == 0) {
        bits = 0x00400000;
    } else {
        bits = (uint32_t) x << 23;
    }

    return as_type<float>(bits);
}
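
// scalar overload so templated kernels can call dot() uniformly on scalar and vector types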
static inline float dot(float x, float y) {
    return x*y;
}

// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
    reg = (type4x4)(*src);
}

template <typename type4>
void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
    reg = (type4)(*src);
}

template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
    reg = (type4x4)(*src);
}

template <typename type4>
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
    reg = (type4)(*(src));
}

#if defined(GGML_METAL_HAS_BF16)
template <typename type4x4>
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
    reg = (type4x4)(*src);
}

template <typename type4>
void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
    reg = (type4)(*(src));
}
#endif
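
// for the quantized types below, il selects which slice of the block to dequantize;
// the _t4 variants produce 4 values per call instead of a full 4x4 tile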
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
    const float d1 = il ? (xb->d / 16.h) : xb->d;
    const float d2 = d1 / 256.f;
    const float md = -8.h * xb->d;
    const ushort mask0 = il ? 0x00F0 : 0x000F;
    const ushort mask1 = mask0 << 8;

    float4x4 reg_f;

    for (int i = 0; i < 8; i++) {
        reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
        reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
    }

    reg = (type4x4) reg_f;
}

template <typename type4>
void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
    const float d2 = d1 / 256.f;
    const float md = -8.h * xb->d;
    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
    const ushort mask1 = mask0 << 8;

    for (int i = 0; i < 2; i++) {
        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
    }
}
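
// q4_0: 32 weights per block with one fp16 scale; d is chosen so the value with the
// largest magnitude maps to -8, and each weight is reconstructed as d*(q - 8)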
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
#pragma METAL fp math_mode(safe)
    float amax = 0.0f; // absolute max
    float max  = 0.0f;

    for (int j = 0; j < QK4_0; j++) {
        const float v = src[j];
        if (amax < fabs(v)) {
            amax = fabs(v);
            max  = v;
        }
    }

    const float d  = max / -8;
    const float id = d ? 1.0f/d : 0.0f;

    dst.d = d;

    for (int j = 0; j < QK4_0/2; ++j) {
        const float x0 = src[0       + j]*id;
        const float x1 = src[QK4_0/2 + j]*id;

        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));

        dst.qs[j]  = xi0;
        dst.qs[j] |= xi1 << 4;
    }
}

void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
#pragma METAL fp math_mode(safe)
    float min = FLT_MAX;
    float max = -FLT_MAX;

    for (int j = 0; j < QK4_1; j++) {
        const float v = src[j];
        if (min > v) min = v;
        if (max < v) max = v;
    }

    const float d  = (max - min) / ((1 << 4) - 1);
    const float id = d ? 1.0f/d : 0.0f;

    dst.d = d;
    dst.m = min;

    for (int j = 0; j < QK4_1/2; ++j) {
        const float x0 = (src[0       + j] - min)*id;
        const float x1 = (src[QK4_1/2 + j] - min)*id;

        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));

        dst.qs[j]  = xi0;
        dst.qs[j] |= xi1 << 4;
    }
}
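
// q5_0: like q4_0 but with 5-bit quants; the low 4 bits of each quant are packed in qs
// and the fifth bits are collected in the 32-bit qh, with reconstruction d*(q - 16)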
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
#pragma METAL fp math_mode(safe)
    float amax = 0.0f; // absolute max
    float max  = 0.0f;

    for (int j = 0; j < QK5_0; j++) {
        const float v = src[j];
        if (amax < fabs(v)) {
            amax = fabs(v);
            max  = v;
        }
    }

    const float d  = max / -16;
    const float id = d ? 1.0f/d : 0.0f;

    dst.d = d;

    uint32_t qh = 0;
    for (int j = 0; j < QK5_0/2; ++j) {
        const float x0 = src[0       + j]*id;
        const float x1 = src[QK5_0/2 + j]*id;

        const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
        const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));

        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
    }

    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;

    for (int j = 0; j < 4; ++j) {
        dst.qh[j] = qh8[j];
    }
}

void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
#pragma METAL fp math_mode(safe)
    float max = src[0];
    float min = src[0];

    for (int j = 1; j < QK5_1; j++) {
        const float v = src[j];
        min = v < min ? v : min;
        max = v > max ? v : max;
    }

    const float d  = (max - min) / 31;
    const float id = d ? 1.0f/d : 0.0f;

    dst.d = d;
    dst.m = min;

    uint32_t qh = 0;
    for (int j = 0; j < QK5_1/2; ++j) {
        const float x0 = (src[0       + j] - min)*id;
        const float x1 = (src[QK5_1/2 + j] - min)*id;

        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);

        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
    }

    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;

    for (int j = 0; j < 4; ++j) {
        dst.qh[j] = qh8[j];
    }
}
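
// q8_0: 32 signed 8-bit quants with one fp16 scale; d = amax/127, reconstruction d*q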
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
#pragma METAL fp math_mode(safe)
    float amax = 0.0f; // absolute max

    for (int j = 0; j < QK8_0; j++) {
        const float v = src[j];
        amax = MAX(amax, fabs(v));
    }

    const float d  = amax / ((1 << 7) - 1);
    const float id = d ? 1.0f/d : 0.0f;

    dst.d = d;

    for (int j = 0; j < QK8_0; ++j) {
        const float x0 = src[j]*id;

        dst.qs[j] = round(x0);
    }
}
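
// iq4_nl: 4-bit indices into the non-linear codebook kvalues_iq4nl_f; after the initial
// scale guess, d is refined with a weighted least-squares fit (weights w = x^2), i.e.
// d = sum(w*v*x) / sum(w*v*v) over codebook values v and source values x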
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
#pragma METAL fp math_mode(safe)
    float amax = 0.0f; // absolute max
    float max  = 0.0f;

    for (int j = 0; j < QK4_NL; j++) {
        const float v = src[j];
        if (amax < fabs(v)) {
            amax = fabs(v);
            max  = v;
        }
    }

    const float d  = max / kvalues_iq4nl_f[0];
    const float id = d ? 1.0f/d : 0.0f;

    float sumqx = 0, sumq2 = 0;
    for (int j = 0; j < QK4_NL/2; ++j) {
        const float x0 = src[0        + j]*id;
        const float x1 = src[QK4_NL/2 + j]*id;

        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);

        dst.qs[j] = xi0 | (xi1 << 4);

        const float v0 = kvalues_iq4nl_f[xi0];
        const float v1 = kvalues_iq4nl_f[xi1];
        const float w0 = src[0        + j]*src[0        + j];
        const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
        sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
        sumq2 += w0*v0*v0 + w1*v1*v1;
    }

    dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
}
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
    const float d1 = il ? (xb->d / 16.h) : xb->d;
    const float d2 = d1 / 256.f;
    const float m  = xb->m;
    const ushort mask0 = il ? 0x00F0 : 0x000F;
    const ushort mask1 = mask0 << 8;

    float4x4 reg_f;

    for (int i = 0; i < 8; i++) {
        reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
        reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
    }

    reg = (type4x4) reg_f;
}

template <typename type4>
void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
    const float d2 = d1 / 256.f;
    const float m  = xb->m;
    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
    const ushort mask1 = mask0 << 8;

    for (int i = 0; i < 2; i++) {
        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
    }
}
template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
    const float d = xb->d;
    const float md = -16.h * xb->d;
    const ushort mask = il ? 0x00F0 : 0x000F;

    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const int x_mv = il ? 4 : 0;

    const int gh_mv = il ? 12 : 0;
    const int gh_bk = il ?  0 : 4;

    float4x4 reg_f;

    for (int i = 0; i < 8; i++) {
        // extract the 5th bits for x0 and x1
        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

        // combine the 4 bits from qs with the 5th bit
        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

        reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
        reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
    }

    reg = (type4x4) reg_f;
}

template <typename type4>
void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
    const float d = xb->d;
    const float md = -16.h * xb->d;
    const ushort mask = (il/4) ? 0x00F0 : 0x000F;

    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const int x_mv = (il/4) ? 4 : 0;

    const int gh_mv = (il/4) ? 12 : 0;
    const int gh_bk = (il/4) ?  0 : 4;

    for (int ii = 0; ii < 2; ii++) {
        int i = 2*(il%4) + ii;

        // extract the 5th bits for x0 and x1
        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

        // combine the 4 bits from qs with the 5th bit
        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

        reg[2*ii + 0] = d * x0 + md;
        reg[2*ii + 1] = d * x1 + md;
    }
}
template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
    const float d = xb->d;
    const float m = xb->m;
    const ushort mask = il ? 0x00F0 : 0x000F;

    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const int x_mv = il ? 4 : 0;

    const int gh_mv = il ? 12 : 0;
    const int gh_bk = il ?  0 : 4;

    float4x4 reg_f;

    for (int i = 0; i < 8; i++) {
        // extract the 5th bits for x0 and x1
        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

        // combine the 4 bits from qs with the 5th bit
        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

        reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
        reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
    }

    reg = (type4x4) reg_f;
}

template <typename type4>
void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
    const float d = xb->d;
    const float m = xb->m;
    const ushort mask = (il/4) ? 0x00F0 : 0x000F;

    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const int x_mv = (il/4) ? 4 : 0;

    const int gh_mv = (il/4) ? 12 : 0;
    const int gh_bk = (il/4) ?  0 : 4;

    for (int ii = 0; ii < 2; ii++) {
        int i = 2*(il%4) + ii;

        // extract the 5th bits for x0 and x1
        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

        // combine the 4 bits from qs with the 5th bit
        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

        reg[2*ii + 0] = d * x0 + m;
        reg[2*ii + 1] = d * x1 + m;
    }
}
template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 * xb, short il, thread type4x4 & reg) {
    device const int8_t * qs = ((device const int8_t *)xb->qs);
    const float d = xb->d;

    float4x4 reg_f;

    for (int i = 0; i < 16; i++) {
        reg_f[i/4][i%4] = (qs[i + 16*il] * d);
    }

    reg = (type4x4) reg_f;
}

template <typename type4>
void dequantize_q8_0_t4(device const block_q8_0 * xb, short il, thread type4 & reg) {
    device const int8_t * qs = ((device const int8_t *)xb->qs);
    const float d = xb->d;

    for (int i = 0; i < 4; i++) {
        reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
    }
}

template <typename type4x4>
void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
    device const uint8_t * q2 = (device const uint8_t *)xb->qs;

    const float d = e8m0_to_fp32(xb->e);

    const uint8_t shr = il >= 1 ? 4 : 0;

    for (int i = 0; i < 4; ++i) {
        reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
        reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
        reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
        reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
    }
}

template <typename type4>
void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
    device const uint8_t * q2 = (device const uint8_t *)xb->qs;

    const float d = e8m0_to_fp32(xb->e);

    const short il4 = il%4;

    const uint8_t shr = il >= 4 ? 4 : 0;

    reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
    reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
    reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
    reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
}
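
// K-quants: super-blocks of QK_K = 256 weights split into sub-blocks, each with its own
// packed scale (and, for some types, min), selected below via the il index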
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & reg) {
    const float d = xb->d;
    const float min = xb->dmin;
    device const uint8_t * q = (device const uint8_t *)xb->qs;
    float dl, ml;
    uint8_t sc = xb->scales[il];

    q = q + 32*(il/8) + 16*(il&1);
    il = (il/2)%4;

    half  coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
    uchar mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
    dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
    }
}

template <typename type4x4>
void dequantize_q3_K(device const block_q3_K * xb, short il, thread type4x4 & reg) {
    const half d_all = xb->d;
    device const uint8_t * q = (device const uint8_t *)xb->qs;
    device const uint8_t * h = (device const uint8_t *)xb->hmask;
    device const int8_t * scales = (device const int8_t *)xb->scales;

    q = q + 32 * (il/8) + 16 * (il&1);
    h = h + 16 * (il&1);
    uint8_t m = 1 << (il/2);
    uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
                                 ((il/4)>0 ? 12  : 3);
    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
    const float ml = 4.f * dl;

    il = (il/2) & 3;
    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
    const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
    dl *= coef;

    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
    }
}
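
// unpack the 6-bit (scale, min) pair number j+k from the packed 12-byte q4_K/q5_K scales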
// unpack the 6-bit scale and 6-bit min for sub-block j from the packed scales array
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
}

template <typename type4x4>
void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
    device const uchar * q = xb->qs;

    short is = (il/4) * 2;
    q = q + (il/4) * 32 + 16 * (il&1);
    il = il & 3;

    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
    const float d   = il < 2 ? xb->d : xb->d / 16.h;
    const float min = xb->dmin;
    const float dl = d * sc[0];
    const float ml = min * sc[1];

    const ushort mask = il < 2 ? 0x0F : 0xF0;
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
    }
}

// like q4_K, plus one high bit per quant taken from qh
template <typename type4x4>
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
    device const uint8_t * q  = xb->qs;
    device const uint8_t * qh = xb->qh;

    short is = (il/4) * 2;
    q  = q + 32 * (il/4) + 16 * (il&1);
    qh = qh + 16 * (il&1);

    uint8_t ul = 1 << (il/2);
    il = il & 3;

    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
    const float d   = il < 2 ? xb->d : xb->d / 16.f;
    const float min = xb->dmin;
    const float dl = d * sc[0];
    const float ml = min * sc[1];

    const ushort mask   = il<2 ? 0x0F : 0xF0;
    const float  qh_val = il<2 ? 16.f : 256.f;
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
    }
}

// 4 low bits per quant in ql, 2 high bits in qh, 8-bit scales, offset of 32
template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
    const half d_all = xb->d;

    device const uint16_t * ql = (device const uint16_t *)xb->ql;
    device const uint16_t * qh = (device const uint16_t *)xb->qh;
    device const int8_t * scales = (device const int8_t *)xb->scales;

    ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
    qh = qh + 16*(il/8) + 8*(il&1);
    float sc = scales[(il%2) + 2*(il/2)];
    il = (il/2) & 3;

    const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
    const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;

    const float ml  = d_all * sc * 32.f;
    const float dl0 = d_all * sc;
    const float dl1 = dl0 / 256.f;
    const float dl2 = dl0 / (256.f * 256.f);
    const float dl3 = dl0 / (256.f * 256.f * 256.f);

    const uint8_t shr_h = il>2 ? 2 : 0;
    const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
    const uint8_t shr_l = il>1 ? 4 : 0;

    for (int i = 0; i < 4; ++i) {
        const uint32_t low  = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
        const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
        const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);

        reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
        reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
        reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
        reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
    }
}
template <typename type4x4>
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
    device const uint16_t * q2 = xb->qs + 4*ib32;
    const uint32_t aux32_g = q2[0] | (q2[1] << 16);
    const uint32_t aux32_s = q2[2] | (q2[3] << 16);
    thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;

    const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
    constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
    uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
    for (int i = 0; i < 8; ++i) {
        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
    }

    grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
    signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
    for (int i = 0; i < 8; ++i) {
        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
    }
}

template <typename type4x4>
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint16_t * q2 = xb->qs + 4*ib32;

    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
    constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
    uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
    for (int i = 0; i < 8; ++i) {
        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
    }

    grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
    signs = ksigns_iq2xs[q2[2*il+1] >> 9];
    for (int i = 0; i < 8; ++i) {
        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
    }
}

template <typename type4x4>
void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint8_t * q3 = xb->qs + 8*ib32;
    device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
    const uint32_t aux32 = gas[0] | (gas[1] << 16);

    const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
    constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
    constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
    uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
    for (int i = 0; i < 4; ++i) {
        reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
        reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
    }

    grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
    grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
    signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
    for (int i = 0; i < 4; ++i) {
        reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
        reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
    }
}
template <typename type4x4>
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint8_t * qs = xb->qs + 8*ib32;
    device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
    const uint8_t qh = xb->qh[ib32] >> 4*il;

    const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
    constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
    constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
    for (int i = 0; i < 4; ++i) {
        reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
        reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
    }

    grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
    grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
    for (int i = 0; i < 4; ++i) {
        reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
        reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
    }
}

template <typename type4x4>
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
    device const uint8_t * signs = qs + QK_K/8;
    const uint8_t qh = xb->qh[ib32] >> 4*il;

    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
    constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
    constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
    for (int i = 0; i < 8; ++i) {
        reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
        reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
    }
}

template <typename type4x4>
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const int ib32 = il/2;
    il = il%2;
    const float d = xb->d;
    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
    device const uint16_t * qh = xb->qh;

    const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
    const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
    const uint16_t h = qh[ib32] >> 6*il;
    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
    for (int i = 0; i < 4; ++i) {
        reg[0][i] = dl * (grid1[i] & 0xf) + ml;
        reg[1][i] = dl * (grid1[i] >> 4) + ml;
        reg[2][i] = dl * (grid2[i] & 0xf) + ml;
        reg[3][i] = dl * (grid2[i] >> 4) + ml;
    }
}

template <typename type4x4>
void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const int ib32 = il/2;
    il = il%2;
    device const uint16_t * sc = (device const uint16_t *)xb->scales;

    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    const float d = scale.f16;

    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
    device const uint8_t * qh = xb->qh + 2*ib32 + il;

    const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
    const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
    const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
    for (int i = 0; i < 4; ++i) {
        reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
        reg[1][i] = dl * (grid1[i] >> 4) + ml1;
        reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
        reg[3][i] = dl * (grid2[i] >> 4) + ml2;
    }
}
template <typename type4x4>
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
    const float d = xb->d;
    uint32_t aux32;
    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
    for (int i = 0; i < 4; ++i) {
        aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
    }
}

template <typename type4>
void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
    const float d = xb->d;
    uint32_t aux32;
    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
    aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
    reg[0] = d * kvalues_iq4nl_f[q8[0]];
    reg[1] = d * kvalues_iq4nl_f[q8[1]];
    reg[2] = d * kvalues_iq4nl_f[q8[2]];
    reg[3] = d * kvalues_iq4nl_f[q8[3]];
}

template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
    const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
    const float d = (float)xb->d * (ls - 32);
    uint32_t aux32;
    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
    for (int i = 0; i < 4; ++i) {
        aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
    }
}

enum ggml_sort_order {
    GGML_SORT_ORDER_ASC,
    GGML_SORT_ORDER_DESC,
};

// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
// pros: works for non-contiguous tensors, supports broadcast across all dims
// cons: not very efficient
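//
// worked example of the modulo broadcast used below (shapes are illustrative,
// not taken from a real graph): with src0 of shape [ne00=8, ne01=4, ne02=2, ne03=1]
// and src1 of shape [8, 1, 1, 1], the indices i13 = i03 % ne13, i12 = i02 % ne12,
// i11 = i01 % ne11 and i10 = i0 % ne10 all collapse to 0 except i10, so the single
// 8-element row of src1 is re-read for every one of the 4*2 rows of src0,
// i.e. a per-row bias add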
template <int F>
kernel void kernel_add_fuse_impl(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig.z;
    const int i02 = tgpig.y;
    const int i01 = tgpig.x;

    const int i13 = i03%args.ne13;
    const int i12 = i02%args.ne12;
    const int i11 = i01%args.ne11;

    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
    device       float * dst_ptr  = (device       float *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);

    device const float * src1_ptr[F];
    for (short j = 0; j < F; ++j) {
        src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
    }

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        const int i10 = i0%args.ne10;

        float res = src0_ptr[i0];

#pragma unroll
        for (short j = 0; j < F; ++j) {
            res += src1_ptr[j][i10];
        }

        dst_ptr[i0] = res;
    }
}

typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;

template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
kernel void kernel_sub_fuse_1(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig.z;
    const int i02 = tgpig.y;
    const int i01 = tgpig.x;

    const int i13 = i03%args.ne13;
    const int i12 = i02%args.ne12;
    const int i11 = i01%args.ne11;

    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        const int i10 = i0%args.ne10;
        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
    }
}

kernel void kernel_mul_fuse_1(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig.z;
    const int i02 = tgpig.y;
    const int i01 = tgpig.x;

    const int i13 = i03%args.ne13;
    const int i12 = i02%args.ne12;
    const int i11 = i01%args.ne11;

    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;

    if (args.ne10 == 1) {
        // src1 row is a single scalar - read it once
        const float x = *((device float *)(src1_ptr));
        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
        }
    } else {
        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
            const int i10 = i0%args.ne10;
            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
        }
    }
}

kernel void kernel_div_fuse_1(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig.z;
    const int i02 = tgpig.y;
    const int i01 = tgpig.x;

    const int i13 = i03%args.ne13;
    const int i12 = i02%args.ne12;
    const int i11 = i01%args.ne11;

    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;

    if (args.ne10 == 1) {
        // src1 row is a single scalar - multiply by its reciprocal instead of dividing per element
        const float x = 1.0f / *((device float *)(src1_ptr));
        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
        }
    } else {
        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
            const int i10 = i0%args.ne10;
            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
        }
    }
}
kernel void kernel_add_id(
        constant ggml_metal_kargs_add_id & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int i1 = tgpig.x;
    const int i2 = tgpig.y;

    const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));

    const size_t nb1 = args.ne0 * sizeof(float);
    const size_t nb2 = args.ne1 * nb1;

    device       float * dst_row  = (device       float *)((device char *)dst + i1*nb1 + i2*nb2);
    device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
    device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        dst_row[i0] = src0_row[i0] + src1_row[i0];
    }
}

template<typename T>
kernel void kernel_repeat(
        constant ggml_metal_kargs_repeat & args,
        device const char * src0,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    const int i03 = i3%args.ne03;
    const int i02 = i2%args.ne02;
    const int i01 = i1%args.ne01;

    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
    device       char * dst_ptr  = dst  + i3*args.nb3   + i2*args.nb2   + i1*args.nb1;

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        const int i00 = i0%args.ne00;
        *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
    }
}

typedef decltype(kernel_repeat<float>) kernel_repeat_t;

template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;

// assumption: src1 is a row
// broadcast src1 into src0
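//
// e.g. for a row of ne00 = 64 floats the kernels below see nb = 16 float4
// elements; thread tpig writes dst[tpig] and reads src1[tpig % nb], so the
// same row is applied to every row of src0 (ne00 is assumed to be a multiple
// of 4 and the shapes are assumed to be validated on the host side)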
template <short F>
kernel void kernel_add_row_c4_fuse_impl(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tpig[[thread_position_in_grid]]) {
    const uint nb = args.ne00/4;
    const uint i  = tpig % nb;

    device const float4 * src0_row = (device const float4 *) (src0);
    device       float4 * dst_row  = (device       float4 *) (dst);

    float4 res = src0_row[tpig];

#pragma unroll(F)
    for (short j = 0; j < F; ++j) {
        res += ((device const float4 *) (src1 + args.o1[j]))[i];
    }

    dst_row[tpig] = res;
}

typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;

template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;

template <short F>
kernel void kernel_sub_row_c4_fuse_impl(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tpig[[thread_position_in_grid]]) {
    const uint nb = args.ne00/4;
    const uint i  = tpig % nb;

    device const float4 * src0_row = (device const float4 *) (src0);
    device       float4 * dst_row  = (device       float4 *) (dst);

    device const float4 * src1_row[F];
    for (short j = 0; j < F; ++j) {
        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
    }

    float4 res = src0_row[tpig];

#pragma unroll(F)
    for (short j = 0; j < F; ++j) {
        res -= src1_row[j][i];
    }

    dst_row[tpig] = res;
}

typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;

template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;

template <short F>
kernel void kernel_mul_row_c4_fuse_impl(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tpig[[thread_position_in_grid]]) {
    const uint nb = args.ne00/4;
    const uint i  = tpig % nb;

    device const float4 * src0_row = (device const float4 *) (src0);
    device       float4 * dst_row  = (device       float4 *) (dst);

    device const float4 * src1_row[F];
    for (short j = 0; j < F; ++j) {
        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
    }

    float4 res = src0_row[tpig];

#pragma unroll(F)
    for (short j = 0; j < F; ++j) {
        res *= src1_row[j][i];
    }

    dst_row[tpig] = res;
}

typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;

template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;

template <short F>
kernel void kernel_div_row_c4_fuse_impl(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tpig[[thread_position_in_grid]]) {
    const uint nb = args.ne00/4;
    const uint i  = tpig % nb;

    device const float4 * src0_row = (device const float4 *) (src0);
    device       float4 * dst_row  = (device       float4 *) (dst);

    device const float4 * src1_row[F];
    for (short j = 0; j < F; ++j) {
        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
    }

    float4 res = src0_row[tpig];

#pragma unroll(F)
    for (short j = 0; j < F; ++j) {
        res /= src1_row[j][i];
    }

    dst_row[tpig] = res;
}

typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;

template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
kernel void kernel_scale_f32(
        constant ggml_metal_kargs_scale & args,
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * args.scale + args.bias;
}

kernel void kernel_scale_f32_4(
        constant ggml_metal_kargs_scale & args,
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * args.scale + args.bias;
}

kernel void kernel_clamp_f32(
        constant ggml_metal_kargs_clamp & args,
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = clamp(src0[tpig], args.min, args.max);
}

kernel void kernel_clamp_f32_4(
        constant ggml_metal_kargs_clamp & args,
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = clamp(src0[tpig], args.min, args.max);
}

kernel void kernel_relu_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = max(0.0f, src0[tpig]);
}

kernel void kernel_relu_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = max(0.0f, src0[tpig]);
}

kernel void kernel_sigmoid_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
}

kernel void kernel_sigmoid_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
}

kernel void kernel_tanh_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = precise::tanh(src0[tpig]);
}

kernel void kernel_tanh_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = precise::tanh(src0[tpig]);
}

constant float GELU_COEF_A     = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
constant float SQRT_2_INV      = 0.70710678118654752440084436210484f;
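
// the kernels below use the tanh-based GELU approximation:
//
//   GELU(x) ~= 0.5*x*(1 + tanh(SQRT_2_OVER_PI*(x + GELU_COEF_A*x^3)))
//
// the exact form 0.5*x*(1 + erf(x*SQRT_2_INV)) is computed separately by
// kernel_gelu_erf_f32 further down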
kernel void kernel_gelu_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float & x = src0[tpig];

    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

kernel void kernel_gelu_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float4 & x = src0[tpig];

    // BEWARE !!!
    // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
    // This was observed with the Falcon 7B and 40B models
    //
    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

kernel void kernel_gelu_quick_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float & x = src0[tpig];

    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}

kernel void kernel_gelu_quick_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float4 & x = src0[tpig];

    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}

// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
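//
// the polynomial below evaluates, for x >= 0 (odd symmetry handles x < 0):
//
//   erf(x) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5)*exp(-x^2),  t = 1/(1 + p*x)
//
// with a maximum absolute error of about 1.5e-7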
constant float p_erf  = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;

template<typename T>
T erf_approx(T x) {
    T sign_x = sign(x);
    x = fabs(x);
    T t = 1.0f / (1.0f + p_erf * x);
    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
    return sign_x * y;
}

kernel void kernel_gelu_erf_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float & x = src0[tpig];

    dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
}

kernel void kernel_gelu_erf_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float4 & x = src0[tpig];

    dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
}

kernel void kernel_silu_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float & x = src0[tpig];

    dst[tpig] = x / (1.0f + exp(-x));
}

kernel void kernel_silu_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float4 & x = src0[tpig];

    dst[tpig] = x / (1.0f + exp(-x));
}

kernel void kernel_elu_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];

    dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
}

kernel void kernel_elu_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 x = src0[tpig];

    dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
    dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
    dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
    dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
}

kernel void kernel_sqr_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * src0[tpig];
}

kernel void kernel_sqr_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * src0[tpig];
}

kernel void kernel_sqrt_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = sqrt(src0[tpig]);
}

kernel void kernel_sqrt_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = sqrt(src0[tpig]);
}

kernel void kernel_sin_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = sin(src0[tpig]);
}

kernel void kernel_sin_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = sin(src0[tpig]);
}

kernel void kernel_cos_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = cos(src0[tpig]);
}

kernel void kernel_cos_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = cos(src0[tpig]);
}

kernel void kernel_log_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = log(src0[tpig]);
}

kernel void kernel_log_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = log(src0[tpig]);
}

kernel void kernel_neg_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = -src0[tpig];
}

kernel void kernel_neg_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = -src0[tpig];
}
kernel void kernel_abs_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = fabs(src0[tpig]);
}

kernel void kernel_abs_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = fabs(src0[tpig]);
}

kernel void kernel_sgn_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = sign(src0[tpig]);
}

kernel void kernel_sgn_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = sign(src0[tpig]);
}

kernel void kernel_step_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = step(0.0f, src0[tpig]);
}

kernel void kernel_step_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = step(0.0f, src0[tpig]);
}

kernel void kernel_hardswish_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];

    dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}

kernel void kernel_hardswish_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 x = src0[tpig];

    dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}

kernel void kernel_hardsigmoid_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];

    dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}

kernel void kernel_hardsigmoid_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 x = src0[tpig];

    dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}

kernel void kernel_exp_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = exp(src0[tpig]);
}

kernel void kernel_exp_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = exp(src0[tpig]);
}
kernel void kernel_reglu_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
        const float x0 = src0_row[i0];
        const float x1 = src1_row[i0];

        dst_row[i0] = x0*x1*(x0 > 0.0f);
    }
}

kernel void kernel_geglu_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
        const float x0 = src0_row[i0];
        const float x1 = src1_row[i0];

        const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));

        dst_row[i0] = gelu*x1;
    }
}

kernel void kernel_swiglu_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
        const float x0 = src0_row[i0];
        const float x1 = src1_row[i0];

        const float silu = x0 / (1.0f + exp(-x0));

        dst_row[i0] = silu*x1;
    }
}

kernel void kernel_swiglu_oai_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
        float x0 = src0_row[i0];
        float x1 = src1_row[i0];

        x0 = min(x0, args.limit);
        x1 = max(min(x1, args.limit), -args.limit);

        float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
        out_glu = out_glu * (1.0f + x1);

        dst_row[i0] = out_glu;
    }
}

kernel void kernel_geglu_erf_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
        const float x0 = src0_row[i0];
        const float x1 = src1_row[i0];

        const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));

        dst_row[i0] = gelu_erf*x1;
    }
}

kernel void kernel_geglu_quick_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
        const float x0 = src0_row[i0];
        const float x1 = src1_row[i0];

        const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));

        dst_row[i0] = gelu_quick*x1;
    }
}
kernel void kernel_op_sum_f32(
        constant ggml_metal_kargs_sum & args,
        device const float * src0,
        device       float * dst,
        threadgroup float  * shmem_f32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    if (args.np == 0) {
        return;
    }

    const uint nsg = (ntg.x + 31) / 32;

    float sumf = 0;
    for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
        sumf += src0[i0];
    }

    sumf = simd_sum(sumf);

    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    float total = 0;
    if (sgitg == 0) {
        float v = 0;
        if (tpitg.x < nsg) {
            v = shmem_f32[tpitg.x];
        }

        total = simd_sum(v);

        if (tpitg.x == 0) {
            dst[0] = total;
        }
    }
}

template <bool norm>
kernel void kernel_sum_rows(
        constant ggml_metal_kargs_sum_rows & args,
        device const float * src0,
        device       float * dst,
        threadgroup float  * shmem_f32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    int64_t i3 = tgpig.z;
    int64_t i2 = tgpig.y;
    int64_t i1 = tgpig.x;

    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
        return;
    }

    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }

    device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

    float sumf = 0;
    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        sumf += src_row[i0];
    }

    sumf = simd_sum(sumf);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    sumf = shmem_f32[tiisg];
    sumf = simd_sum(sumf);

    if (tpitg.x == 0) {
        dst_row[0] = norm ? sumf / args.ne00 : sumf;
    }
}

typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;

template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean_f32")]]     kernel kernel_sum_rows_t kernel_sum_rows<true>;
template<typename T>
kernel void kernel_cumsum_blk(
        constant ggml_metal_kargs_cumsum_blk & args,
        device const char * src0,
        device       char * tmp,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int ib  = tgpig[0]/args.ne01;
    const int i00 = ib*ntg.x;
    const int i01 = tgpig[0]%args.ne01;
    const int i02 = tgpig[1];
    const int i03 = tgpig[2];

    device const float * src0_row = (device const float *) (src0 +
            args.nb01*i01 +
            args.nb02*i02 +
            args.nb03*i03);

    threadgroup float * shmem_f32 = (threadgroup float *) shmem;

    float v = 0.0f;
    if (i00 + tpitg.x < args.ne00) {
        v = src0_row[i00 + tpitg.x];
    }

    float s = simd_prefix_inclusive_sum(v);

    if (tiisg == N_SIMDWIDTH - 1) {
        shmem_f32[sgitg] = s;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (sgitg == 0) {
        shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    s += shmem_f32[sgitg];

    device float * dst_row = (device float *) dst +
            args.ne00*i01 +
            args.ne00*args.ne01*i02 +
            args.ne00*args.ne01*args.ne02*i03;

    if (i00 + tpitg.x < args.ne00) {
        dst_row[i00 + tpitg.x] = s;
    }

    if (args.outb && tpitg.x == ntg.x - 1) {
        device float * tmp_row = (device float *) tmp +
                args.net0*i01 +
                args.net0*args.net1*i02 +
                args.net0*args.net1*args.net2*i03;

        tmp_row[ib] = s;
    }
}

typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;

template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;

template<typename T>
kernel void kernel_cumsum_add(
        constant ggml_metal_kargs_cumsum_add & args,
        device const char * tmp,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int ib = tgpig[0]/args.ne01;
    if (ib == 0) {
        return;
    }

    const int i00 = ib*ntg.x;
    const int i01 = tgpig[0]%args.ne01;
    const int i02 = tgpig[1];
    const int i03 = tgpig[2];

    device const float * tmp_row = (device const float *) (tmp +
            args.nbt1*i01 +
            args.nbt2*i02 +
            args.nbt3*i03);

    device float * dst_row = (device float *) dst +
            args.ne00*i01 +
            args.ne00*args.ne01*i02 +
            args.ne00*args.ne01*args.ne02*i03;

    if (i00 + tpitg.x < args.ne00) {
        dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
    }
}

typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;

template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
template<typename T>
kernel void kernel_soft_max(
        constant ggml_metal_kargs_soft_max & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        threadgroup float * buf [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint3 tptg[[threads_per_threadgroup]]) {
    const int32_t i03 = tgpig.z;
    const int32_t i02 = tgpig.y;
    const int32_t i01 = tgpig.x;

    const int32_t i13 = i03%args.ne13;
    const int32_t i12 = i02%args.ne12;
    const int32_t i11 = i01;

    device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
    device const T     * pmask = src1 != src0 ? (device const T *) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
    device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
    device       float * pdst  = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);

    float slope = 1.0f;

    // ALiBi
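    // per-head slope: slope = m0^(h+1) for the first n_head_log2 heads and
    // m1^(2*(h - n_head_log2) + 1) for the rest; m0 and m1 are assumed to be
    // precomputed on the host from args.max_bias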
  1613. if (args.max_bias > 0.0f) {
  1614. const int32_t h = i02;
  1615. const float base = h < args.n_head_log2 ? args.m0 : args.m1;
  1616. const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
  1617. slope = pow(base, exp);
  1618. }
  1619. // parallel max
  1620. float lmax = psrc2 ? psrc2[i02] : -INFINITY;
  1621. for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
  1622. lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
  1623. }
  1624. // find the max value in the block
  1625. float max_val = simd_max(lmax);
  1626. if (tptg.x > N_SIMDWIDTH) {
  1627. if (sgitg == 0) {
  1628. buf[tiisg] = -INFINITY;
  1629. }
  1630. threadgroup_barrier(mem_flags::mem_threadgroup);
  1631. if (tiisg == 0) {
  1632. buf[sgitg] = max_val;
  1633. }
  1634. threadgroup_barrier(mem_flags::mem_threadgroup);
  1635. max_val = buf[tiisg];
  1636. max_val = simd_max(max_val);
  1637. }
  1638. // parallel sum
  1639. float lsum = 0.0f;
  1640. for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
  1641. const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
  1642. lsum += exp_psrc0;
  1643. pdst[i00] = exp_psrc0;
  1644. }
  1645. // This barrier fixes a failing test
  1646. // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
  1647. threadgroup_barrier(mem_flags::mem_none);
  1648. float sum = simd_sum(lsum);
  1649. if (tptg.x > N_SIMDWIDTH) {
  1650. if (sgitg == 0) {
  1651. buf[tiisg] = 0.0f;
  1652. }
  1653. threadgroup_barrier(mem_flags::mem_threadgroup);
  1654. if (tiisg == 0) {
  1655. buf[sgitg] = sum;
  1656. }
  1657. threadgroup_barrier(mem_flags::mem_threadgroup);
  1658. sum = buf[tiisg];
  1659. sum = simd_sum(sum);
  1660. }
  1661. if (psrc2) {
  1662. sum += exp(psrc2[i02] - max_val);
  1663. }
  1664. const float inv_sum = 1.0f/sum;
  1665. for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
  1666. pdst[i00] *= inv_sum;
  1667. }
  1668. }
template<typename T>
kernel void kernel_soft_max_4(
        constant ggml_metal_kargs_soft_max & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device char * dst,
        threadgroup float * buf [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint3 tptg[[threads_per_threadgroup]]) {
    const int32_t i03 = tgpig.z;
    const int32_t i02 = tgpig.y;
    const int32_t i01 = tgpig.x;
    const int32_t i13 = i03%args.ne13;
    const int32_t i12 = i02%args.ne12;
    const int32_t i11 = i01;
    device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
    device const T      * pmask = src1 != src0 ? (device const T *) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
    device const float  * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
    device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
    float slope = 1.0f;
    if (args.max_bias > 0.0f) {
        const int32_t h = i02;
        const float base = h < args.n_head_log2 ? args.m0 : args.m1;
        const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
        slope = pow(base, exp);
    }
    // parallel max
    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
        lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
    }
    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
    float max_val = simd_max(lmax);
    if (tptg.x > N_SIMDWIDTH) {
        if (sgitg == 0) {
            buf[tiisg] = -INFINITY;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) {
            buf[sgitg] = max_val;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        max_val = buf[tiisg];
        max_val = simd_max(max_val);
    }
    // parallel sum
    float4 lsum4 = 0.0f;
    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
        const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
        lsum4 += exp_psrc4;
        pdst4[i00] = exp_psrc4;
    }
    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
    // This barrier fixes a failing test
    // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
    threadgroup_barrier(mem_flags::mem_none);
    float sum = simd_sum(lsum);
    if (tptg.x > N_SIMDWIDTH) {
        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) {
            buf[sgitg] = sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sum = buf[tiisg];
        sum = simd_sum(sum);
    }
    if (psrc2) {
        sum += exp(psrc2[i02] - max_val);
    }
    const float inv_sum = 1.0f/sum;
    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
        pdst4[i00] *= inv_sum;
    }
}

typedef decltype(kernel_soft_max<float>)    kernel_soft_max_t;
typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;

template [[host_name("kernel_soft_max_f16")]]   kernel kernel_soft_max_t   kernel_soft_max<half>;
template [[host_name("kernel_soft_max_f32")]]   kernel kernel_soft_max_t   kernel_soft_max<float>;
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
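
// note on the optional src2 input handled above: when psrc2 is set, these
// kernels fold one extra per-row logit into the normalization (a sketch of the
// math; "sink" is our label for it, not a name used in this file):
//   p_i = exp(x_i*scale + slope*mask_i - M) / (sum_j exp(x_j*scale + slope*mask_j - M) + exp(psrc2 - M))
// where M is the global maximum including psrc2, so the extra logit
// contributes to the denominator only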
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
kernel void kernel_ssm_conv_f32_f32(
        constant ggml_metal_kargs_ssm_conv & args,
        device const void * src0,
        device const void * src1,
        device float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const int64_t ir = tgpig.x;
    const int64_t i2 = tgpig.y;
    const int64_t i3 = tgpig.z;
    const int64_t nc = args.ne10;
    //const int64_t ncs = args.ne00;
    //const int64_t nr = args.ne01;
    //const int64_t n_t = args.ne1;
    //const int64_t n_s = args.ne2;
    device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
    device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
    device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
    float sumf = 0.0f;
    for (int64_t i0 = 0; i0 < nc; ++i0) {
        sumf += s[i0] * c[i0];
    }
    x[0] = sumf;
}

kernel void kernel_ssm_conv_f32_f32_4(
        constant ggml_metal_kargs_ssm_conv & args,
        device const void * src0,
        device const void * src1,
        device float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const int64_t ir = tgpig.x;
    const int64_t i2 = tgpig.y;
    const int64_t i3 = tgpig.z;
    const int64_t nc = args.ne10;
    //const int64_t ncs = args.ne00;
    //const int64_t nr = args.ne01;
    //const int64_t n_t = args.ne1;
    //const int64_t n_s = args.ne2;
    device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
    device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
    device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
    float sumf = 0.0f;
    for (int64_t i0 = 0; i0 < nc/4; ++i0) {
        sumf += dot(s[i0], c[i0]);
    }
    x[0] = sumf;
}

// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
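// a sketch of the per-token recurrence this kernel evaluates (following the
// Mamba-2 reference above; dtsp = softplus(dt), bypassed for dt > 20 where
// softplus(x) ~= x):
//   s_t = exp(dtsp*A) * s_{t-1} + B_t * (x_t * dtsp)
//   y_t = dot(C_t, s_t)   // reduced over d_state via simd_sum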
kernel void kernel_ssm_scan_f32(
        constant ggml_metal_kargs_ssm_scan & args,
        device const void * src0,
        device const void * src1,
        device const void * src2,
        device const void * src3,
        device const void * src4,
        device const void * src5,
        device const void * src6,
        device float * dst,
        threadgroup float * shared [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgptg[[simdgroups_per_threadgroup]],
        uint3 tgpg[[threadgroups_per_grid]]) {
    constexpr short NW = N_SIMDWIDTH;
    shared[tpitg.x] = 0.0f;
    const int32_t i0 = tpitg.x;
    const int32_t i1 = tgpig.x;
    const int32_t ir = tgpig.y; // current head
    const int32_t i3 = tgpig.z; // current seq
    const int32_t nc  = args.d_state;
    const int32_t nr  = args.d_inner;
    const int32_t nh  = args.n_head;
    const int32_t ng  = args.n_group;
    const int32_t n_t = args.n_seq_tokens;
    const int32_t s_off = args.s_off;
    device const int32_t * ids = (device const int32_t *) src6;
    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
    device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
    const int32_t i = i0 + i1*nc;
    const int32_t g = ir / (nh / ng); // repeat_interleave
    float s0 = s0_buff[i];
    float s  = 0.0f;
    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
    const float A0 = A[i0%args.ne30];
    device const float * x  = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
    device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
    device const float * B  = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns}
    device const float * C  = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns}
    device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}
    for (int i2 = 0; i2 < n_t; i2 += sgptg) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
            const float dt0  = dt[0];
            const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
            const float x_dt = x[0] * dtsp;
            const float dA   = exp(dtsp * A0);
            s = (s0 * dA) + (B[i0] * x_dt);
            const float sumf = simd_sum(s * C[i0]);
            if (tiisg == 0) {
                shared[t*NW + sgitg] = sumf;
            }
            // recurse
            s0 = s;
            x  += args.ns12;
            dt += args.ns21;
            B  += args.ns42;
            C  += args.ns52;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
        if (tiisg == 0 && i2 + sgitg < n_t) {
            y[sgitg*nh*nr] = sumf;
        }
        y += sgptg*nh*nr;
    }
    s_buff[i] = s;
}
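
// RWKV-6 "wkv" recurrence; a sketch of what each thread computes per token
// for its channel, with kv = k_t * v_t:
//   y_t   = dot(r_t, tf * kv + state)
//   state = state * td_t + kv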
kernel void kernel_rwkv_wkv6_f32(
        device const float * k,
        device const float * v,
        device const float * r,
        device const float * tf,
        device const float * td,
        device const float * state_in,
        device float * dst,
        constant uint & B,
        constant uint & T,
        constant uint & C,
        constant uint & H,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const uint head_size = 64; // TODO: support head_size = 128
    const uint batch_id = tgpig.x / H;
    const uint head_id = tgpig.x % H;
    const uint tid = tpitg.x;
    if (batch_id >= B || head_id >= H) {
        return;
    }
    const uint state_size = C * head_size;
    const uint n_seq_tokens = T / B;
    threadgroup float _k[head_size];
    threadgroup float _r[head_size];
    threadgroup float _tf[head_size];
    threadgroup float _td[head_size];
    float state[head_size];
    for (uint i = 0; i < head_size; i++) {
        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
                          + i * head_size + tid];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    _tf[tid] = tf[head_id * head_size + tid];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
    for (uint t = start_t; t < end_t; t += C) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        _k[tid] = k[t];
        _r[tid] = r[t];
        _td[tid] = td[t];
        threadgroup_barrier(mem_flags::mem_threadgroup);
        const float v_val = v[t];
        float y = 0.0;
        for (uint j = 0; j < head_size; j += 4) {
            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
            float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
            float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
            float4 kv = k_vec * v_val;
            float4 temp = tf_vec * kv + s_vec;
            y += dot(r_vec, temp);
            s_vec = s_vec * td_vec + kv;
            state[j] = s_vec[0];
            state[j+1] = s_vec[1];
            state[j+2] = s_vec[2];
            state[j+3] = s_vec[3];
        }
        dst[t] = y;
    }
    for (uint i = 0; i < head_size; i++) {
        dst[T * C + batch_id * state_size + head_id * head_size * head_size
          + i * head_size + tid] = state[i];
    }
}
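
// RWKV-7 "wkv" recurrence; a sketch of the per-token update, with kv = k_t * v_t:
//   sa    = dot(a_t, state)
//   state = state * w_t + kv + sa * b_t
//   y_t   = dot(r_t, state)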
kernel void kernel_rwkv_wkv7_f32(
        device const float * r,
        device const float * w,
        device const float * k,
        device const float * v,
        device const float * a,
        device const float * b,
        device const float * state_in,
        device float * dst,
        constant uint & B,
        constant uint & T,
        constant uint & C,
        constant uint & H,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const uint head_size = 64; // TODO: support head_size = 128
    const uint batch_id = tgpig.x / H;
    const uint head_id = tgpig.x % H;
    const uint tid = tpitg.x;
    if (batch_id >= B || head_id >= H) {
        return;
    }
    const uint state_size = C * head_size;
    const uint n_seq_tokens = T / B;
    threadgroup float _r[head_size];
    threadgroup float _w[head_size];
    threadgroup float _k[head_size];
    threadgroup float _a[head_size];
    threadgroup float _b[head_size];
    float state[head_size];
    for (uint i = 0; i < head_size; i++) {
        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
                          + tid * head_size + i];
    }
    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
    for (uint t = start_t; t < end_t; t += C) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        _r[tid] = r[t];
        _w[tid] = w[t];
        _k[tid] = k[t];
        _a[tid] = a[t];
        _b[tid] = b[t];
        threadgroup_barrier(mem_flags::mem_threadgroup);
        const float v_val = v[t];
        float y = 0.0, sa = 0.0;
        float4 sa_vec(0.0);
        for (uint j = 0; j < head_size; j += 4) {
            float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
            sa_vec += a_vec * s_vec;
        }
        sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
        for (uint j = 0; j < head_size; j += 4) {
            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
            float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
            float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
            float4 kv = k_vec * v_val;
            s_vec = s_vec * w_vec + kv + sa * b_vec;
            y += dot(s_vec, r_vec);
            state[j] = s_vec[0];
            state[j+1] = s_vec[1];
            state[j+2] = s_vec[2];
            state[j+3] = s_vec[3];
        }
        dst[t] = y;
    }
    for (uint i = 0; i < head_size; i++) {
        dst[T * C + batch_id * state_size + head_id * head_size * head_size
          + tid * head_size + i] = state[i];
    }
}
kernel void kernel_argmax_f32(
        constant ggml_metal_kargs_argmax & args,
        device const char * src0,
        device char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    device const float * x_row = (device const float *) ((device const char *) src0 + tgpig * args.nb01);
    float lmax = -INFINITY;
    int32_t larg = -1;
    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
        if (x_row[i00] > lmax) {
            lmax = x_row[i00];
            larg = i00;
        }
    }
    // find the argmax value in the block
    float max_val = simd_max(lmax);
    int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
    device int32_t * dst_i32 = (device int32_t *) dst;
    threadgroup float   * shared_maxval = (threadgroup float *) shmem;
    threadgroup int32_t * shared_argmax = (threadgroup int32_t *) shmem + N_SIMDWIDTH;
    if (ntg > N_SIMDWIDTH) {
        if (sgitg == 0) {
            shared_maxval[tiisg] = -INFINITY;
            shared_argmax[tiisg] = -1;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) {
            shared_maxval[sgitg] = max_val;
            shared_argmax[sgitg] = arg_val;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        max_val = shared_maxval[tiisg];
        arg_val = shared_argmax[tiisg];
        float max_val_reduced = simd_max(max_val);
        int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
        dst_i32[tgpig] = arg_val_reduced;
        return;
    }
    dst_i32[tgpig] = arg_val;
}

// F == 1 : norm (no fuse)
// F == 2 : norm + mul
// F == 3 : norm + mul + add
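// i.e. for F == 3 each row is transformed as:
//   y = ((x - mean(x)) / sqrt(var(x) + eps)) * f0 + f1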
template <typename T, short F>
kernel void kernel_norm_fuse_impl(
        constant ggml_metal_kargs_norm & args,
        device const char * src0,
        device const char * src1_0,
        device const char * src1_1,
        device char * dst,
        threadgroup float * shmem_f32 [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }
    const int i01 = tgpig.x;
    const int i02 = tgpig.y;
    const int i03 = tgpig.z;
    device const T * x  = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
    device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
    device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
    T sumft(0.0f);
    float sumf = 0.0f;
    for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
        sumft += x[i00];
    }
    sumf = dot(sumft, T(1.0f));
    sumf = simd_sum(sumf);
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    sumf = shmem_f32[tiisg];
    sumf = simd_sum(sumf);
    const float mean = sumf/args.ne00;
    device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
    sumf = 0.0f;
    for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
        y[i00] = x[i00] - mean;
        sumf += dot(y[i00], y[i00]);
    }
    sumf = simd_sum(sumf);
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    sumf = shmem_f32[tiisg];
    sumf = simd_sum(sumf);
    const float variance = sumf/args.ne00;
    const float scale = 1.0f/sqrt(variance + args.eps);
    for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
        if (F == 1) {
            y[i00] = (y[i00]*scale);
        }
        if (F == 2) {
            y[i00] = (y[i00]*scale)*f0[i00];
        }
        if (F == 3) {
            y[i00] = (y[i00]*scale)*f0[i00] + f1[i00];
        }
    }
}

typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;

template [[host_name("kernel_norm_f32")]]           kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;
template [[host_name("kernel_norm_mul_f32")]]       kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>;
template [[host_name("kernel_norm_mul_add_f32")]]   kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;
template [[host_name("kernel_norm_f32_4")]]         kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_norm_mul_f32_4")]]     kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;

// F == 1 : rms_norm (no fuse)
// F == 2 : rms_norm + mul
// F == 3 : rms_norm + mul + add
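// i.e. for F == 3 each row is transformed as:
//   y = (x / sqrt(mean(x^2) + eps)) * f0 + f1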
template <typename T, short F>
kernel void kernel_rms_norm_fuse_impl(
        constant ggml_metal_kargs_norm & args,
        device const char * src0,
        device const char * src1_0,
        device const char * src1_1,
        device char * dst,
        threadgroup float * shmem_f32 [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }
    const int i01 = tgpig.x;
    const int i02 = tgpig.y;
    const int i03 = tgpig.z;
    device const T * x  = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
    device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
    device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
    float sumf = 0.0f;
    // parallel sum
    for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
        sumf += dot(x[i00], x[i00]);
    }
    sumf = simd_sum(sumf);
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    sumf = shmem_f32[tiisg];
    sumf = simd_sum(sumf);
    const float mean = sumf/args.ne00;
    const float scale = 1.0f/sqrt(mean + args.eps);
    device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
    for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
        if (F == 1) {
            y[i00] = (x[i00]*scale);
        }
        if (F == 2) {
            y[i00] = (x[i00]*scale)*f0[i00];
        }
        if (F == 3) {
            y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
        }
    }
}

typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;

template [[host_name("kernel_rms_norm_f32")]]           kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;
template [[host_name("kernel_rms_norm_mul_f32")]]       kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32")]]   kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;
template [[host_name("kernel_rms_norm_f32_4")]]         kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_rms_norm_mul_f32_4")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;

kernel void kernel_l2_norm_f32(
        constant ggml_metal_kargs_l2_norm & args,
        device const char * src0,
        device char * dst,
        threadgroup float * shmem_f32 [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        ushort tpitg[[thread_position_in_threadgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort ntg[[threads_per_threadgroup]]) {
    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }
    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
    float sumf = 0.0f;
    // parallel sum
    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
        sumf += dot(x[i00], x[i00]);
    }
    sumf = simd_sum(sumf);
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    sumf = shmem_f32[tiisg];
    sumf = simd_sum(sumf);
    const float scale = 1.0f/sqrt(max(sumf, args.eps));
    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
        y[i00] = x[i00] * scale;
    }
}

kernel void kernel_group_norm_f32(
        constant ggml_metal_kargs_group_norm & args,
        device const float * src0,
        device float * dst,
        threadgroup float * buf [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    const int64_t ne = args.ne00*args.ne01*args.ne02;
    const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp);
    int start = tgpig * gs;
    int end   = start + gs;
    start += tpitg;
    if (end >= ne) {
        end = ne;
    }
    float tmp = 0.0f; // partial sum for thread in warp
    for (int j = start; j < end; j += ntg) {
        tmp += src0[j];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    tmp = simd_sum(tmp);
    if (ntg > N_SIMDWIDTH) {
        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) {
            buf[sgitg] = tmp;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        tmp = buf[tiisg];
        tmp = simd_sum(tmp);
    }
    const float mean = tmp / gs;
    tmp = 0.0f;
    for (int j = start; j < end; j += ntg) {
        float xi = src0[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }
    tmp = simd_sum(tmp);
    if (ntg > N_SIMDWIDTH) {
        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tiisg == 0) {
            buf[sgitg] = tmp;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        tmp = buf[tiisg];
        tmp = simd_sum(tmp);
    }
    const float variance = tmp / gs;
    const float scale = 1.0f/sqrt(variance + args.eps);
    for (int j = start; j < end; j += ntg) {
        dst[j] *= scale;
    }
}

// function to calculate the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
// il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl's have been multiplied by the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
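// for example, the nibble selected by 0x0F00 carries its value shifted left by
// 8 bits, so the matching yl entry is expected to be pre-divided by 256 to
// cancel the implicit << 8 (see the yl[] setup in mul_vec_q_n_f32_impl below)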
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;
    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
    for (int i = 0; i < 8; i += 2) {
        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
    }
    return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
}

// function to calculate the inner product between half a q4_1 block and 16 floats (yl); sumy is SUM(yl[i])
// il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl's have been multiplied by the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;
    float m = qb_curr->m;
    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);
    for (int i = 0; i < 8; i += 2) {
        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
    }
    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
}

// function to calculate the inner product between half a q5_0 block and 16 floats (yl); sumy is SUM(yl[i])
// il indicates where the q5 quants begin (0 or QK5_0/4)
// we assume that the yl's have been multiplied by the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;
    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 3 + il/2);
    const uint32_t qh = *((device const uint32_t *) qb_curr->qh);
    for (int i = 0; i < 8; i += 2) {
        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));
        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
    }
    return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
}

// function to calculate the inner product between half a q5_1 block and 16 floats (yl); sumy is SUM(yl[i])
// il indicates where the q5 quants begin (0 or QK5_1/4)
// we assume that the yl's have been multiplied by the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;
    float m = qb_curr->m;
    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 4 + il/2);
    const uint32_t qh = *((device const uint32_t *) qb_curr->qh);
    for (int i = 0; i < 8; i += 2) {
        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));
        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
    }
    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
}
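
// reduce the per-row partial sums across all simdgroups in the threadgroup and
// write the final values: a two-stage reduction - simd_sum within each
// simdgroup, an exchange through threadgroup memory, then a second simd_sum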
template<short NR0>
static inline void helper_mv_reduce_and_write(
        device float * dst_f32,
        float sumf[NR0],
        const int r0,
        const int ne01,
        ushort tiisg,
        ushort sgitg,
        threadgroup char * shmem) {
    constexpr short NW = N_SIMDWIDTH;
    threadgroup float * shmem_f32[NR0];
    for (short row = 0; row < NR0; ++row) {
        shmem_f32[row] = (threadgroup float *) shmem + NW*row;
        if (sgitg == 0) {
            shmem_f32[row][tiisg] = 0.0f;
        }
        sumf[row] = simd_sum(sumf[row]);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (short row = 0; row < NR0; ++row) {
        if (tiisg == 0) {
            shmem_f32[row][sgitg] = sumf[row];
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (short row = 0; row < NR0 && r0 + row < ne01; ++row) {
        float tot = simd_sum(shmem_f32[row][tiisg]);
        if (tiisg == 0 && sgitg == 0) {
            dst_f32[r0 + row] = tot;
        }
    }
}

constant short FC_mul_mv_nsg   [[function_constant(FC_MUL_MV + 0)]];
constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];

template<typename block_q_type, short NR0, typename args_t>
void mul_vec_q_n_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3 tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;
    constexpr short NW = N_SIMDWIDTH;
    constexpr short NQ = 16;
    const int nb = args.ne00/QK4_0;
    const int r0 = (tgpig.x*NSG + sgitg)*NR0;
    //const int r0 = tgpig.x*NR0;
    const int r1 = tgpig.y;
    const int im = tgpig.z;
    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
    //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
    //device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
    device const float * y = (device const float *) (src1 + offset1);
    // pointers to src0 rows
    device const block_q_type * ax[NR0];
    FOR_UNROLL (int row = 0; row < NR0; ++row) {
        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
        ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
    }
    float sumf[NR0] = {0.f};
    const short ix = (tiisg/(NW/NQ));
    const short il = (tiisg%(NW/NQ))*8;
    //const int ib0 = sgitg*NQ + ix;
    const int ib0 = ix;
    float yl[16]; // src1 vector cache
    //device const float * yb = y + ix*QK4_0 + il;
    device const float * yb = y + ib0*QK4_0 + il;
    // each thread in a SIMD group deals with half a block.
    //for (int ib = ib0; ib < nb; ib += NSG*NQ) {
    for (int ib = ib0; ib < nb; ib += NQ) {
        float sumy[2] = { 0.f, 0.f };
        FOR_UNROLL (short i = 0; i < 8; i += 2) {
            sumy[0] += yb[i + 0] + yb[i + 1];
            yl[i + 0] = yb[i + 0];
            yl[i + 1] = yb[i + 1]/256.f;
            sumy[1] += yb[i + 16] + yb[i + 17];
            yl[i + 8] = yb[i + 16]/16.f;
            yl[i + 9] = yb[i + 17]/4096.f;
        }
        FOR_UNROLL (short row = 0; row < NR0; row++) {
            sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
        }
        yb += QK4_0 * 16;
        //yb += NSG*NQ*QK4_0;
    }
    device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
    //helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
    for (int row = 0; row < NR0; ++row) {
        const float tot = simd_sum(sumf[row]);
        if (tiisg == 0 && r0 + row < args.ne01) {
            dst_f32[r0 + row] = tot;
        }
    }
}

kernel void kernel_mul_mv_q4_0_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

kernel void kernel_mul_mv_q4_1_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

kernel void kernel_mul_mv_q5_0_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

kernel void kernel_mul_mv_q5_1_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

template<short NR0, typename args_t>
void kernel_mul_mv_q8_0_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3 tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;
    constexpr short NW = N_SIMDWIDTH;
    constexpr short NQ = 8;
    const int nb = args.ne00/QK8_0;
    const int r0 = tgpig.x*NR0;
    const int r1 = tgpig.y;
    const int im = tgpig.z;
    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
    //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
    //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
    device const float * y = (device const float *) (src1 + offset1);
    // pointers to src0 rows
    device const block_q8_0 * ax[NR0];
    FOR_UNROLL (short row = 0; row < NR0; ++row) {
        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
        ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
    }
    float sumf[NR0] = { 0.f };
    const short ix = tiisg/(NW/NQ);
    const short il = tiisg%(NW/NQ);
    const int ib0 = sgitg*NQ + ix;
    float yl[NQ];
    device const float * yb = y + ib0*QK8_0 + il*NQ;
    // each thread in a SIMD group deals with NQ quants at a time
    for (int ib = ib0; ib < nb; ib += NSG*NQ) {
        for (short i = 0; i < NQ; ++i) {
            yl[i] = yb[i];
        }
        for (short row = 0; row < NR0; row++) {
            device const int8_t * qs = ax[row][ib].qs + il*NQ;
            float sumq = 0.f;
            FOR_UNROLL (short i = 0; i < NQ; ++i) {
                sumq += qs[i] * yl[i];
            }
            sumf[row] += sumq*ax[row][ib].d;
        }
        yb += NSG*NQ*QK8_0;
    }
    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
    helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}

[[host_name("kernel_mul_mv_q8_0_f32")]]
kernel void kernel_mul_mv_q8_0_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

// mat-vec kernel processing in chunks of float4
// chpb - chunks per quantization block
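// e.g. a block_q8_0 holds 32 quants, so it spans chpb = 32/4 = 8 float4 chunks
// (the dispatchers below pass epb/4 for this parameter)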
template<short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
void kernel_mul_mv_ext_q4_f32_impl(
        constant ggml_metal_kargs_mul_mv_ext & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    const short NSG   = FC_mul_mv_nsg;
    const short nxpsg = FC_mul_mv_nxpsg;
    const short chpt  = 4; // chunks per thread
    //const short nxpsg = (32);
    const short nypsg = (32/nxpsg);
    const short tx = tiisg%nxpsg;
    const short ty = tiisg/nxpsg;
    const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
    const int i11 = tgpig.y*r1ptg;
    const int i1m = tgpig.z;
    const int i12 = i1m%args.ne12;
    const int i13 = i1m/args.ne12;
    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
    device const float4 * y4[r1ptg];
    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
        y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
    }
    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
    short cch = tx%chpb; // current chunk index
    for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
        float4 lx[chpt];
#pragma unroll(chpt)
        for (short ch = 0; ch < chpt; ++ch) {
            deq_t4(xq, cch, lx[ch]);
            cch += nxpsg;
            if (cch >= chpb) {
                xq += cch/chpb;
                cch %= chpb;
            }
        }
#pragma unroll(chpt)
        for (short ch = 0; ch < chpt; ++ch) {
#pragma unroll(r1ptg)
            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
                sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
            }
        }
#pragma unroll(r1ptg)
        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
            y4[ir1] += chpt*nxpsg;
        }
    }
    // reduce only the threads in each row
    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
        if (nxpsg >= 32) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
        }
        if (nxpsg >= 16) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1], 8);
        }
        if (nxpsg >= 8) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1], 4);
        }
        if (nxpsg >= 4) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1], 2);
        }
        if (nxpsg >= 2) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1], 1);
        }
        //sumf[ir1] = simd_sum(sumf[ir1]);
    }
    if (tx == 0) {
        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
            if (i01 < args.ne01) {
                dst_f32[i01] = sumf[ir1];
            }
        }
    }
}

// mat-vec kernel processing in chunks of float4x4
template<short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
void kernel_mul_mv_ext_q4x4_f32_impl(
        constant ggml_metal_kargs_mul_mv_ext & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    const short NSG   = FC_mul_mv_nsg;
    const short nxpsg = FC_mul_mv_nxpsg;
    const short chpt  = 1;
    //const short nxpsg = (32);
    const short nypsg = (32/nxpsg);
    const short tx = tiisg%nxpsg;
    const short ty = tiisg/nxpsg;
    const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
    const int i11 = tgpig.y*r1ptg;
    const int i1m = tgpig.z;
    const int i12 = i1m%args.ne12;
    const int i13 = i1m/args.ne12;
    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
    device const float4x4 * y4x4[r1ptg];
    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
        y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1;
    }
    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
    short cch = tx%chpb;
    for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) {
        float4x4 lx[chpt];
#pragma unroll(chpt)
        for (short ch = 0; ch < chpt; ++ch) {
            deq_t4x4(xq, cch, lx[ch]);
            cch += nxpsg;
            if (cch >= chpb) {
                xq += cch/chpb;
                cch %= chpb;
            }
        }
#pragma unroll(chpt)
        for (short ch = 0; ch < chpt; ++ch) {
#pragma unroll(r1ptg)
            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
                sumf[ir1] +=
                    dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) +
                    dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) +
                    dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) +
                    dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]);
            }
        }
#pragma unroll(r1ptg)
        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
            y4x4[ir1] += chpt*nxpsg;
        }
    }
    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
        if (nxpsg >= 32) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
        }
        if (nxpsg >= 16) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1], 8);
        }
        if (nxpsg >= 8) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1], 4);
        }
        if (nxpsg >= 4) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1], 2);
        }
        if (nxpsg >= 2) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1], 1);
        }
        //sumf[ir1] = simd_sum(sumf[ir1]);
    }
    if (tx == 0) {
        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
            if (i01 < args.ne01) {
                dst_f32[i01] = sumf[ir1];
            }
        }
    }
}

// dispatchers needed for compile-time nxpsg
// epb - elements per quantization block
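// the q4 dispatcher consumes float4 chunks (chpb = epb/4), the q4x4 dispatcher
// consumes float4x4 chunks (chpb = epb/16)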
template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
kernel void kernel_mul_mv_ext_q4_f32_disp(
        constant ggml_metal_kargs_mul_mv_ext & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_ext_q4_f32_impl<r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
}

template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
kernel void kernel_mul_mv_ext_q4x4_f32_disp(
        constant ggml_metal_kargs_mul_mv_ext & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_ext_q4x4_f32_impl<r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
}

typedef decltype(kernel_mul_mv_ext_q4_f32_disp  <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>)   mul_mv_ext_q4x4_f32_t;

template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4, 4, dequantize_f32_t4>;
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4, 4, dequantize_f32_t4>;
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4, 4, dequantize_f32_t4>;
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4, 4, dequantize_f32_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1, 32, dequantize_q4_1_t4>;
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1, 32, dequantize_q4_1_t4>;
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1, 32, dequantize_q4_1_t4>;
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1, 32, dequantize_q4_1_t4>;
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0, 32, dequantize_q5_0_t4>;
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0, 32, dequantize_q5_0_t4>;
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0, 32, dequantize_q5_0_t4>;
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0, 32, dequantize_q5_0_t4>;
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1, 32, dequantize_q5_1_t4>;
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1, 32, dequantize_q5_1_t4>;
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1, 32, dequantize_q5_1_t4>;
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1, 32, dequantize_q5_1_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0, 32, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4, 32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4, 32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4, 32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4, 32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>;
template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>;
template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>;
template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>;
template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>;
template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>;
template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>;
template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>;
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
  2815. template<typename T0, typename T1, short NR0, typename args_t>
  2816. void kernel_mul_mv_t_t_impl(
  2817. args_t args,
  2818. device const char * src0,
  2819. device const char * src1,
  2820. device char * dst,
  2821. threadgroup char * shmem,
  2822. uint3 tgpig,
  2823. ushort tiisg,
  2824. ushort sgitg) {
  2825. const short NSG = FC_mul_mv_nsg;
  2826. constexpr short NW = N_SIMDWIDTH;
  2827. constexpr short NB = 32;
  2828. constexpr short NF = 8;
  2829. const int nb = args.ne00/NB;
  2830. const int r0 = tgpig.x*NR0;
  2831. const int r1 = tgpig.y;
  2832. const int im = tgpig.z;
  2833. const uint i12 = im%args.ne12;
  2834. const uint i13 = im/args.ne12;
  2835. //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
  2836. const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
  2837. //device const T0 * x = (device const T0 *) (src0 + offset0);
  2838. device const T1 * y = (device const T1 *) (src1 + offset1);
  2839. // pointers to src0 rows
  2840. device const T0 * ax [NR0];
  2841. FOR_UNROLL (short row = 0; row < NR0; ++row) {
  2842. const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
  2843. ax[row] = (device const T0 *) ((device char *) src0 + offset0);
  2844. }
  2845. float sumf[NR0] = { 0.f };
  2846. const short ix = tiisg/(NW/NF);
  2847. const short il = tiisg%(NW/NF);
  2848. const int ib0 = sgitg*NF + ix;
  2849. T1 yl[NF];
  2850. device const T1 * yb = y + (ib0*NB + il*NF);
  2851. for (int ib = ib0; ib < nb; ib += NSG*NF) {
  2852. for (short i = 0; i < NF; ++i) {
  2853. yl[i] = yb[i];
  2854. }
  2855. for (short row = 0; row < NR0; row++) {
  2856. device const T0 * xb = ax[row] + (ib*NB + il*NF);
  2857. float sumq = 0.f;
  2858. FOR_UNROLL (short i = 0; i < NF; ++i) {
  2859. sumq += xb[i] * yl[i];
  2860. }
  2861. sumf[row] += sumq;
  2862. }
  2863. yb += NSG*NF*NW;
  2864. }
  2865. for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
  2866. for (short row = 0; row < NR0; row++) {
  2867. sumf[row] += ax[row][i] * y[i];
  2868. }
  2869. }
  2870. device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
  2871. helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
  2872. }

template<typename T0, typename T1, typename args_t>
void kernel_mul_mv_t_t_disp(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    switch (args.nr0) {
      //case 1: kernel_mul_mv_t_t_impl<T0, T1, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
        case 2: kernel_mul_mv_t_t_impl<T0, T1, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
      //case 3: kernel_mul_mv_t_t_impl<T0, T1, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
      //case 4: kernel_mul_mv_t_t_impl<T0, T1, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
    }
}

template<typename T0, typename T1>
kernel void kernel_mul_mv_t_t(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_t_t_disp<T0, T1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

typedef decltype(kernel_mul_mv_t_t<half, half>) mul_mv_t_t;

template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<float,  float>;
template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,   float>;
template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,   half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float>;
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat>;
#endif
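
// same structure as kernel_mul_mv_t_t, but the data is also accessed through 4-wide vector types
// (T04/T14): each thread consumes NF = 16 scalars per iteration as NF4 = 4 float4 dot() products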

template<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t>
void kernel_mul_mv_t_t_4_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    constexpr short NW  = N_SIMDWIDTH;
    constexpr short NB  = 32;
    constexpr short NF  = 16;
    constexpr short NF4 = NF/4;

    const int nb = args.ne00/NB;

    const int r0 = tgpig.x*NR0;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const T1  * y  = (device const T1  *) (src1 + offset1);
    device const T14 * y4 = (device const T14 *) (src1 + offset1);

    // pointers to src0 rows
    device const T0  * ax [NR0];
    device const T04 * ax4[NR0];
    FOR_UNROLL (short row = 0; row < NR0; ++row) {
        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

        ax [row] = (device const T0  *) ((device char *) src0 + offset0);
        ax4[row] = (device const T04 *) ((device char *) src0 + offset0);
    }

    float sumf[NR0] = { 0.f };

    const short ix = tiisg/(NW/NF);
    const short il = tiisg%(NW/NF);

    const int ib0 = sgitg*NF + ix;

    T14 yl4[NF4];

    device const T14 * yb4 = y4 + (ib0*NB + il*NF)/4;

    for (int ib = ib0; ib < nb; ib += NSG*NF) {
        for (short i = 0; i < NF4; ++i) {
            yl4[i] = yb4[i];
        }

        for (short row = 0; row < NR0; row++) {
            device const T04 * xb4 = ax4[row] + (ib*NB + il*NF)/4;

            float sumq = 0.f;
            FOR_UNROLL (short i = 0; i < NF4; ++i) {
                sumq += dot(float4(xb4[i]), float4(yl4[i]));
            }

            sumf[row] += sumq;
        }

        yb4 += NSG*NF*NW/4;
    }

    for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
        for (short row = 0; row < NR0; row++) {
            sumf[row] += ax[row][i] * y[i];
        }
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}

template<typename T0, typename T04, typename T1, typename T14, typename args_t>
void kernel_mul_mv_t_t_4_disp(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    switch (args.nr0) {
      //case 1: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
        case 2: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
      //case 3: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
      //case 4: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
    }
}

template<typename T0, typename T04, typename T1, typename T14>
kernel void kernel_mul_mv_t_t_4(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_t_t_4_disp<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4>) mul_mv_t_t_4;

template [[host_name("kernel_mul_mv_f32_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float,  float4,  float,  float4>;
template [[host_name("kernel_mul_mv_f16_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,   half4,   float,  float4>;
template [[host_name("kernel_mul_mv_f16_f16_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,   half4,   half,   half4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32_4")]]  kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float,  float4>;
template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4>;
#endif

template<typename T0, typename T1, typename args_t>
void kernel_mul_mv_t_t_short_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig,
        ushort tiisg) {
    const int r0 = tgpig.x*32 + tiisg;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    if (r0 >= args.ne01) {
        return;
    }

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

    device const T0 * x = (device const T0 *) (src0 + offset0);

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;

    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const T1 * y = (device const T1 *) (src1 + offset1);

    float res = 0.0f;

    for (int i = 0; i < args.ne00; ++i) {
        res += (float) x[i] * (float) y[i];
    }

    dst_f32[(uint64_t)r1*args.ne0 + r0] = res;
}

template<typename T0, typename T1>
kernel void kernel_mul_mv_t_t_short(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]]) {
    kernel_mul_mv_t_t_short_impl<T0, T1, constant ggml_metal_kargs_mul_mv &>(
            args,
            src0,
            src1,
            dst,
            tgpig,
            tiisg);
}

typedef decltype(kernel_mul_mv_t_t_short<half, half>) mul_mv_t_t_short_t;

template [[host_name("kernel_mul_mv_f32_f32_short")]]   kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<float,  float>;
template [[host_name("kernel_mul_mv_f16_f32_short")]]   kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half,   float>;
template [[host_name("kernel_mul_mv_f16_f16_short")]]   kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half,   half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32_short")]]  kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, float>;
template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
#endif
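
// note: the _short variants assign one thread per output row and compute a plain serial dot product,
//       with no simdgroup reduction and no shared memory - presumably a better fit for small ne00,
//       where the setup cost of the tiled kernels above would dominate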

constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];

static float rope_yarn_ramp(const float low, const float high, const int i0) {
    const float y = (i0 / 2 - low) / max(0.001f, high - low);

    return 1.0f - min(1.0f, max(0.0f, y));
}

// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
        float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
        thread float * cos_theta, thread float * sin_theta) {
    // Get n-d rotational scaling corrected for extrapolation
    float theta_interp = freq_scale * theta_extrap;
    float theta = theta_interp;
    if (ext_factor != 0.0f) {
        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

        // Get n-d magnitude scaling corrected for interpolation
        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
    }
    *cos_theta = cos(theta) * mscale;
    *sin_theta = sin(theta) * mscale;
}

// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
}

static void rope_yarn_corr_dims(
    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
    // start and end correction dims
    dims[0] = max(0.0f,          floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
    dims[1] = min(n_dims - 1.0f,  ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
}
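
// example with illustrative values n_dims = 128, n_ctx_orig = 4096, freq_base = 10000,
// beta_fast = 32, beta_slow = 1:
//   corr_fac(32) = 128*log(4096/(32*2*pi))/(2*log(10000)) ~= 20.9 -> dims[0] = 20
//   corr_fac(1)  = 128*log(4096/( 1*2*pi))/(2*log(10000)) ~= 45.0 -> dims[1] = 46
// i.e. dims below ~20 are extrapolated, dims above ~46 are interpolated, with a ramp in between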

template<typename T>
kernel void kernel_rope_norm(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i3 = tgpig[2];
    const int i2 = tgpig[1];
    const int i1 = tgpig[0];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float theta_base = (float) pos[i2];
    const float inv_ndims = -1.f/args.n_dims;

    float cos_theta;
    float sin_theta;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 < args.n_dims) {
            const int ic = i0/2;

            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

            const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);

            const float x0 = src[0];
            const float x1 = src[1];

            dst_data[0] = x0*cos_theta - x1*sin_theta;
            dst_data[1] = x0*sin_theta + x1*cos_theta;
        } else {
            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];
        }
    }
}
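
// note: kernel_rope_norm above rotates adjacent element pairs (i0, i0 + 1), while kernel_rope_neox
//       below rotates split-half pairs (ic, ic + n_dims/2) - the trigonometry is otherwise identical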

template<typename T>
kernel void kernel_rope_neox(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i3 = tgpig[2];
    const int i2 = tgpig[1];
    const int i1 = tgpig[0];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float theta_base = (float) pos[i2];
    const float inv_ndims = -1.f/args.n_dims;

    float cos_theta;
    float sin_theta;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 < args.n_dims) {
            const int ic = i0/2;

            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

            const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);

            const float x0 = src[0];
            const float x1 = src[args.n_dims/2];

            dst_data[0]             = x0*cos_theta - x1*sin_theta;
            dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
        } else {
            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];
        }
    }
}

template<typename T>
kernel void kernel_rope_multi(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i3 = tgpig[2];
    const int i2 = tgpig[1];
    const int i1 = tgpig[0];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float inv_ndims = -1.f/args.n_dims;

    float cos_theta;
    float sin_theta;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 < args.n_dims) {
            const int ic = i0/2;

            // mrope theta calculations
            // note: the rest is the same as kernel_rope_neox
            const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
            const int sec_w01   = args.sect_0 + args.sect_1;               // end of section 1
            const int sec_w012  = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
            const int sector    = ic % sect_dims;

            float theta_base;
            if (FC_rope_is_imrope) {
                if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
                    theta_base = (float) pos[i2 + args.ne02 * 1];
                } else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
                    theta_base = (float) pos[i2 + args.ne02 * 2];
                } else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
                    theta_base = (float) pos[i2 + args.ne02 * 0];
                } else { // e
                    theta_base = (float) pos[i2 + args.ne02 * 3];
                }
            } else {
                if (sector < args.sect_0) {
                    theta_base = (float) pos[i2];
                } else if (sector < sec_w01) {
                    theta_base = (float) pos[i2 + args.ne02 * 1];
                } else if (sector < sec_w012) {
                    theta_base = (float) pos[i2 + args.ne02 * 2];
                } else {
                    theta_base = (float) pos[i2 + args.ne02 * 3];
                }
            }
            // end of mrope

            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

            const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);

            const float x0 = src[0];
            const float x1 = src[args.n_dims/2];

            dst_data[0]             = x0*cos_theta - x1*sin_theta;
            dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
        } else {
            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];
        }
    }
}

template<typename T>
kernel void kernel_rope_vision(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i3 = tgpig[2];
    const int i2 = tgpig[1];
    const int i1 = tgpig[0];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float inv_ndims = -1.f/args.n_dims;

    float cos_theta;
    float sin_theta;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
            const int ic = i0/2;

            // mrope theta calculations (only support 2 dimensions)
            const int sect_dims = args.sect_0 + args.sect_1;
            const int sector    = ic % sect_dims;

            float p;
            float theta_base;
            if (sector < args.sect_1) {
                p = (float) sector;
                theta_base = (float) pos[i2];
            } else {
                p = (float) sector - args.sect_0;
                theta_base = (float) pos[i2 + args.ne02];
            }

            const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
            // end of mrope

            const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;

            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);

            const float x0 = src[0];
            const float x1 = src[args.n_dims]; // different from kernel_rope_multi

            dst_data[0]           = x0*cos_theta - x1*sin_theta;
            dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
        } else {
            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];
        }
    }
}
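
// note: unlike kernel_rope_multi, the vision variant rotates pairs (ic, ic + n_dims) across the full
//       2*n_dims width, and the frequency exponent p restarts from 0 within each of the two sections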

typedef decltype(kernel_rope_norm<float>)   kernel_rope_norm_t;
typedef decltype(kernel_rope_neox<float>)   kernel_rope_neox_t;
typedef decltype(kernel_rope_multi<float>)  kernel_rope_multi_t;
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;

template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;

template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;

template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;

template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;

typedef void (im2col_t)(
        constant ggml_metal_kargs_im2col & args,
        device const float * x,
        device       char  * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]);

template <typename T>
kernel void kernel_im2col(
        constant ggml_metal_kargs_im2col & args,
        device const float * x,
        device       char  * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
//  const int64_t IC = tgpg[0];
    const int64_t OH = tgpg[1];
    const int64_t OW = tgpg[2];

    const int64_t KH = ntg[1];
    const int64_t KW = ntg[2];

          int64_t in  = tpitg[0];
    const int64_t ikh = tpitg[1];
    const int64_t ikw = tpitg[2];

    const int64_t iic = tgpig[0];
    const int64_t ioh = tgpig[1];
    const int64_t iow = tgpig[2];

    const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
    const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;

    int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);

    device T * pdst = (device T *) (dst);

    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
        while (in < args.N) {
            pdst[offset_dst] = 0.0f;
            offset_dst += ntg[0]*args.CHW*OH*OW;
            in += ntg[0];
        }
    } else {
        int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
        while (in < args.N) {
            pdst[offset_dst] = x[offset_src];
            offset_dst += ntg[0]*args.CHW*OH*OW;
            offset_src += ntg[0]*args.ofs0;
            in += ntg[0];
        }
    }
}

template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
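
// note: the resulting dst layout is [N*OH*OW, IC*KH*KW]:
//
//   offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*KH*KW + ikh*KW + ikw)
//
// i.e. one flattened KHxKW input patch per output pixel, so the convolution itself reduces to a
// matrix multiplication with the flattened kernel weights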

// TODO: obsolete -- remove
//typedef void (im2col_ext_t)(
//        constant ggml_metal_kargs_im2col & args,
//        device const float * x,
//        device       char  * dst,
//        uint3 tgpig[[threadgroup_position_in_grid]],
//        uint3 tgpg[[threadgroups_per_grid]],
//        uint3 tpitg[[thread_position_in_threadgroup]],
//        uint3 ntg[[threads_per_threadgroup]]);
//
//template <typename T>
//kernel void kernel_im2col_ext(
//        constant ggml_metal_kargs_im2col & args,
//        device const float * x,
//        device       char  * dst,
//        uint3 tgpig[[threadgroup_position_in_grid]],
//        uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
//        uint3 tpitg[[thread_position_in_threadgroup]],
//        uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
//    const int64_t KHW = (int64_t)args.KHW;
//
//    const int64_t d       = tgpig[0] / args.CHW;
//    const int64_t chw     = tgpig[0] % args.CHW;
//    const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
//    const int64_t HW      = tgpig[0] % KHW;
//
//    const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
//    if (tpitg_0 >= args.N) {
//        return;
//    }
//
//    const int64_t tpitg_1 = HW / args.KW;
//    const int64_t tpitg_2 = HW % args.KW;
//
//    const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
//    const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
//
//    const int64_t offset_dst =
//        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
//        (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
//
//    device T * pdst = (device T *) (dst);
//
//    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
//        pdst[offset_dst] = 0.0f;
//    } else {
//        const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
//        pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
//    }
//}
//
//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;

template <typename TK>
kernel void kernel_conv_2d(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
    const uint tg_index       = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
    const uint local_thread   = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
    const uint thread_index   = tg_index * threads_per_tg + local_thread;

    const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
    const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;

    for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
        uint64_t tmp = index;
        const int32_t ow = tmp % args.OW; tmp /= args.OW;
        const int32_t oh = tmp % args.OH; tmp /= args.OH;
        const int32_t oc = tmp % args.OC; tmp /= args.OC;
        const int32_t n  = tmp;

        float acc = 0.0f;

        const int32_t base_x = ow*args.s0 - args.p0;
        const int32_t base_y = oh*args.s1 - args.p1;

        int32_t ky_start = 0;
        if (base_y < 0) {
            ky_start = (-base_y + args.d1 - 1)/args.d1;
        }

        int32_t ky_end = args.KH;
        const int32_t y_max = args.IH - 1 - base_y;
        if (y_max < 0) {
            ky_end = ky_start;
        } else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
            ky_end = min(ky_end, y_max/args.d1 + 1);
        }

        int32_t kx_start = 0;
        if (base_x < 0) {
            kx_start = (-base_x + args.d0 - 1)/args.d0;
        }

        int32_t kx_end = args.KW;
        const int32_t x_max = args.IW - 1 - base_x;
        if (x_max < 0) {
            kx_end = kx_start;
        } else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
            kx_end = min(kx_end, x_max/args.d0 + 1);
        }

        if (ky_start < ky_end && kx_start < kx_end) {
            const uint64_t src_base_n = (uint64_t) n  * args.nb13;
            const uint64_t w_base_oc  = (uint64_t) oc * args.nb03;

            for (int32_t ic = 0; ic < args.IC; ++ic) {
                const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
                const uint64_t w_base_ocic = w_base_oc  + (uint64_t) ic * args.nb02;

                for (int32_t ky = ky_start; ky < ky_end; ++ky) {
                    const int32_t iy = base_y + ky*args.d1;

                    const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
                    const uint64_t w_base_row   = w_base_ocic + (uint64_t) ky * args.nb01;

                    for (int32_t kx = kx_start; kx < kx_end; ++kx) {
                        const int32_t ix = base_x + kx*args.d0;

                        const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
                        const uint64_t w_offs   = w_base_row   + (uint64_t) kx * args.nb00;

                        const float x = *(device const float *)(src + src_offs);
                        const float w = (float) (*(device const TK *)(weights + w_offs));

                        acc += x * w;
                    }
                }
            }
        }

        const uint64_t dst_offs =
            (uint64_t) n  * args.nb3 +
            (uint64_t) oc * args.nb2 +
            (uint64_t) oh * args.nb1 +
            (uint64_t) ow * args.nb0;

        *(device float *)(dst + dst_offs) = acc;
    }
}

template [[host_name("kernel_conv_2d_f32_f32")]]
kernel void kernel_conv_2d<float>(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]);

template [[host_name("kernel_conv_2d_f16_f32")]]
kernel void kernel_conv_2d<half>(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]);
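
// note: kernel_conv_2d uses a grid-stride loop over all N*OC*OH*OW outputs, and the kx/ky ranges are
//       pre-clamped against the input bounds so that the hot inner loops need no per-tap checks.
//       e.g. base_y = -2, d1 = 2 -> ky_start = (2 + 2 - 1)/2 = 1, i.e. the first valid tap lands on iy = 0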

typedef void (conv_transpose_1d_t)(
        constant ggml_metal_kargs_conv_transpose_1d & args,
        device const float * src0,
        device const float * src1,
        device       char  * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]]);

template <typename T>
kernel void kernel_conv_transpose_1d(
        constant ggml_metal_kargs_conv_transpose_1d & args,
        device const T     * src0,
        device const float * src1,
        device       char  * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]]) {
    float v = 0.0f;

    for (int64_t c = 0; c < args.IC; c++) {
        const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];
        const int32_t input_offset  = c * args.IL;

        for (int64_t i = 0; i < args.IL; i++) {
            if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {
                v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
            }
        }
    }

    device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);

    dst_ptr[0] = v;
}

template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
kernel void kernel_conv_transpose_1d<float>(
        constant ggml_metal_kargs_conv_transpose_1d & args,
        device const float * src0,
        device const float * src1,
        device       char  * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]]);

template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
kernel void kernel_conv_transpose_1d<half>(
        constant ggml_metal_kargs_conv_transpose_1d & args,
        device const half  * src0,
        device const float * src1,
        device       char  * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]]);
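
// note: the transposed convolution above is written in gather form: each output position tgpig[0]
//       sums the contributions of every input sample i whose upsampled footprint [i*s0, i*s0 + K)
//       covers it, instead of scattering each input sample into the output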

typedef void (conv_transpose_2d_t)(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const float * src0,
        device const float * src1,
        device       char  * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]]);

template <typename T>
kernel void kernel_conv_transpose_2d(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const T     * src0,
        device const float * src1,
        device       char  * dst,
        threadgroup float  * shared_sum [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const int64_t out_x = tgpig[0];
    const int64_t out_y = tgpig[1];
    const int64_t out_c = tgpig[2];

    const int64_t kw = tpitg[0];
    const int64_t kh = tpitg[1];

    float v = 0.0f;

    for (int64_t in_c = 0; in_c < args.IC; in_c++) {
        int64_t in_y = out_y - kh;
        if (in_y < 0 || in_y % args.s0) continue;
        in_y /= args.s0;
        if (in_y >= args.IH) continue;

        int64_t in_x = out_x - kw;
        if (in_x < 0 || in_x % args.s0) continue;
        in_x /= args.s0;
        if (in_x >= args.IW) continue;

        const int64_t input_idx  = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
        const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;

        v += (float)src0[kernel_idx] * src1[input_idx];
    }

    const uint tid = tpitg.y * ntg.x + tpitg.x;
    shared_sum[tid] = v;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tid == 0) {
        float total = 0.0f;
        const uint num_threads = ntg.x * ntg.y;
        for (uint i = 0; i < num_threads; i++) {
            total += shared_sum[i];
        }

        device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y*args.nb1 + out_c*args.nb2);
        dst_ptr[0] = total;
    }
}

template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const float * src0,
        device const float * src1,
        device       char  * dst,
        threadgroup float  * shared_sum [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]);

template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const half  * src0,
        device const float * src1,
        device       char  * dst,
        threadgroup float  * shared_sum [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]);

kernel void kernel_upscale_f32(
        constant ggml_metal_kargs_upscale & args,
        device const char * src0,
        device       char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const int64_t i3 = tgpig.z;
    const int64_t i2 = tgpig.y;
    const int64_t i1 = tgpig.x;

    const int64_t i03 = i3/args.sf3;
    const int64_t i02 = i2/args.sf2;
    const int64_t i01 = i1/args.sf1;

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        const int64_t i00 = i0/args.sf0;

        device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
        device       float * dst_ptr  = (device       float *) (dst  + i3*args.nb3   + i2*args.nb2   + i1*args.nb1   + i0*args.nb0);

        dst_ptr[0] = src0_ptr[0];
    }
}

kernel void kernel_pad_f32(
        constant ggml_metal_kargs_pad & args,
        device const char * src0,
        device       char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const int64_t i3 = tgpig.z;
    const int64_t i2 = tgpig.y;
    const int64_t i1 = tgpig.x;

    const int64_t i03 = i3;
    const int64_t i02 = i2;
    const int64_t i01 = i1;

    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
    device       float * dst_ptr  = (device       float *) (dst  + i3*args.nb3   + i2*args.nb2   + i1*args.nb1);

    if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
            if (i0 < args.ne00) {
                dst_ptr[i0] = src0_ptr[i0];
            } else {
                dst_ptr[i0] = 0.0f;
            }
        }

        return;
    }

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        dst_ptr[i0] = 0.0f;
    }
}
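
// note: rows that exist in src are copied and then zero-filled up to ne0; rows that lie entirely in
//       the padded region are zeroed in full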

kernel void kernel_pad_reflect_1d_f32(
        constant ggml_metal_kargs_pad_reflect_1d & args,
        device const char * src0,
        device       char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const int64_t i3 = tgpig.z;
    const int64_t i2 = tgpig.y;
    const int64_t i1 = tgpig.x;

    const int64_t i03 = i3;
    const int64_t i02 = i2;
    const int64_t i01 = i1;

    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
    device       float * dst_ptr  = (device       float *) (dst  + i3*args.nb3   + i2*args.nb2   + i1*args.nb1);

    if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
            if (i0 < args.p0) {
                dst_ptr[i0] = src0_ptr[args.p0 - i0];
            } else if (i0 < args.ne0 - args.p1) {
                dst_ptr[i0] = src0_ptr[i0 - args.p0];
            } else {
                dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
            }
        }
    }
}
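
// example: with p0 = 2, p1 = 2 and a source row [a, b, c, d], the padded row is
//          [c, b, a, b, c, d, c, b] - the reflection does not repeat the edge samples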

kernel void kernel_arange_f32(
        constant ggml_metal_kargs_arange & args,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    device float * dst_ptr = (device float *) dst;

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        dst_ptr[i0] = args.start + args.step * i0;
    }
}

kernel void kernel_timestep_embedding_f32(
        constant ggml_metal_kargs_timestep_embedding & args,
        device const char * src0,
        device       char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    int i = tgpig.x;

    device float * embed_data = (device float *)(dst + i*args.nb1);

    int half_ = args.dim / 2;
    for (int j = tpitg.x; j < half_; j += ntg.x) {
        float timestep = ((device float *)src0)[i];
        float freq = (float)exp(-log((float)args.max_period) * j / half_);
        float arg = timestep * freq;
        embed_data[j        ] = cos(arg);
        embed_data[j + half_] = sin(arg);
    }

    if (args.dim % 2 != 0 && tpitg.x == 0) {
        embed_data[2 * half_] = 0.f;
    }
}
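
// note: this is the usual sinusoidal timestep embedding: for j in [0, dim/2)
//
//   embed[j]         = cos(t * max_period^(-j/(dim/2)))
//   embed[j + dim/2] = sin(t * max_period^(-j/(dim/2)))
//
// so the frequencies decay geometrically from 1 down to ~1/max_period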

// bitonic sort implementation following the CUDA kernels as reference
typedef void (argsort_t)(
        constant ggml_metal_kargs_argsort & args,
        device const char    * src0,
        device       int32_t * dst,
        threadgroup  int32_t * shmem_i32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]);

template<ggml_sort_order order>
kernel void kernel_argsort_f32_i32(
        constant ggml_metal_kargs_argsort & args,
        device const char    * src0,
        device       int32_t * dst,
        threadgroup  int32_t * shmem_i32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    // bitonic sort
    const int col = tpitg[0];

    const int ib  = tgpig[0] / args.ne01;
    const int i00 = ib*ntg.x;
    const int i01 = tgpig[0] % args.ne01;
    const int i02 = tgpig[1];
    const int i03 = tgpig[2];

    device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);

    // initialize indices
    shmem_i32[col] = i00 + col;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int k = 2; k <= ntg.x; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            if (ixj > col) {
                if ((col & k) == 0) {
                    if (shmem_i32[col] >= args.ne00 ||
                        (shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
                            src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
                            src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
                    ) {
                        SWAP(shmem_i32[col], shmem_i32[ixj]);
                    }
                } else {
                    if (shmem_i32[ixj] >= args.ne00 ||
                        (shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
                            src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
                            src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
                    ) {
                        SWAP(shmem_i32[col], shmem_i32[ixj]);
                    }
                }
            }

            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
    }

    const int64_t i0 = ib*args.top_k;

    // copy the result to dst without the padding
    if (i0 + col < args.ne0 && col < args.top_k) {
        dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;

        dst[col] = shmem_i32[col];
    }
}

template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
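
// note: ntg.x is the (power-of-two) padded chunk width; shmem_i32 entries that index past ne00 always
//       lose their comparisons (they act as +inf for ASC and -inf for DESC), so they sink to the end
//       of the chunk and are dropped when the top_k prefix is copied out. the sorted chunks are then
//       combined by kernel_argsort_merge_f32_i32 below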

typedef void (argsort_merge_t)(
        constant ggml_metal_kargs_argsort_merge & args,
        device const char    * src0,
        device const int32_t * tmp,
        device       int32_t * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]);

template<ggml_sort_order order>
kernel void kernel_argsort_merge_f32_i32(
        constant ggml_metal_kargs_argsort_merge & args,
        device const char    * src0,
        device const int32_t * tmp,
        device       int32_t * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int im  = tgpig[0] / args.ne01;
    const int i01 = tgpig[0] % args.ne01;
    const int i02 = tgpig[1];
    const int i03 = tgpig[2];

    const int start = im * (2 * args.len);

    const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
    const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));

    const int total = len0 + len1;

    device const int32_t * tmp0 = tmp + start
        + i01*args.ne0
        + i02*args.ne0*args.ne01
        + i03*args.ne0*args.ne01*args.ne02;

    device const int32_t * tmp1 = tmp0 + args.len;

    dst += start
        + i01*args.top_k
        + i02*args.top_k*args.ne01
        + i03*args.top_k*args.ne01*args.ne02;

    device const float * src0_row = (device const float *)(src0
        + args.nb01*i01
        + args.nb02*i02
        + args.nb03*i03);

    if (total == 0) {
        return;
    }

    const int chunk = (total + ntg.x - 1) / ntg.x;

    const int k0 = tpitg.x * chunk;
    const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);

    if (k0 >= args.top_k) {
        return;
    }

    if (k0 >= total) {
        return;
    }

    int low  = k0 > len1 ? k0 - len1 : 0;
    int high = MIN(k0, len0);

    // binary-search partition (i, j) such that i + j = k
    while (low < high) {
        const int mid = (low + high) >> 1;

        const int32_t idx0 = tmp0[mid];
        const int32_t idx1 = tmp1[k0 - mid - 1];

        const float val0 = src0_row[idx0];
        const float val1 = src0_row[idx1];

        bool take_left;
        if (order == GGML_SORT_ORDER_ASC) {
            take_left = (val0 <= val1);
        } else {
            take_left = (val0 >= val1);
        }

        if (take_left) {
            low = mid + 1;
        } else {
            high = mid;
        }
    }

    int i = low;
    int j = k0 - i;

    // keep the merge fronts in registers
    int32_t idx0 = 0;
    float   val0 = 0.0f;
    if (i < len0) {
        idx0 = tmp0[i];
        val0 = src0_row[idx0];
    }

    int32_t idx1 = 0;
    float   val1 = 0.0f;
    if (j < len1) {
        idx1 = tmp1[j];
        val1 = src0_row[idx1];
    }

    for (int k = k0; k < k1; ++k) {
        int32_t out_idx;

        if (i >= len0) {
            while (k < k1) {
                dst[k++] = tmp1[j++];
            }
            break;
        } else if (j >= len1) {
            while (k < k1) {
                dst[k++] = tmp0[i++];
            }
            break;
        } else {
            bool take_left;
            if (order == GGML_SORT_ORDER_ASC) {
                take_left = (val0 <= val1);
            } else {
                take_left = (val0 >= val1);
            }

            if (take_left) {
                out_idx = idx0;
                ++i;
                if (i < len0) {
                    idx0 = tmp0[i];
                    val0 = src0_row[idx0];
                }
            } else {
                out_idx = idx1;
                ++j;
                if (j < len1) {
                    idx1 = tmp1[j];
                    val1 = src0_row[idx1];
                }
            }
        }

        dst[k] = out_idx;
    }
}

template [[host_name("kernel_argsort_merge_f32_i32_asc")]]  kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
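
// note: merge-path style parallel merge: each thread binary-searches the split point (i, j) with
//       i + j = k0 between the two sorted runs and then merges its own [k0, k1) output slice
//       sequentially, so no synchronization between threads is needed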

kernel void kernel_leaky_relu_f32(
        constant ggml_metal_kargs_leaky_relu & args,
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];

    dst[tpig] = x > 0.0f ? x : x * args.slope;
}

kernel void kernel_leaky_relu_f32_4(
        constant ggml_metal_kargs_leaky_relu & args,
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 x = src0[tpig];

    dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
}
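
// note: in the 4-wide variant, (x > 0.0f) yields a bool4 and float4(...) converts it to a 0/1 mask
//       per component, making the expression a branchless per-lane select between x and x*slope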

constant bool    FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];

constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];

// pad the last chunk of C elements of k and v into an extra pad buffer
kernel void kernel_flash_attn_ext_pad(
        constant ggml_metal_kargs_flash_attn_ext_pad & args,
        device const char * k,
        device const char * v,
        device const char * mask,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int32_t C = FC_flash_attn_ext_pad_ncpsg;

    device char * k_pad    = dst;
    device char * v_pad    = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3;
    device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3;

    const int32_t icp = args.ne11 % C;
    const int32_t ic0 = args.ne11 - icp;

    const int32_t i1 = tgpig[0];
    const int32_t i2 = tgpig[1];
    const int32_t i3 = tgpig[2];

    if (i2 < args.ne_12_2 && i3 < args.ne_12_3) {
        device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3;
        device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3;

        device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3;
        device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;

        if (i1 >= icp) {
            // the exact value written here does not matter - we rely on these scores being masked out in the attention
            for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
                k_dst[i] = 0;
            }
            for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
                v_dst[i] = 0;
            }
        } else {
            for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
                k_dst[i] = k_src[i];
            }
            for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
                v_dst[i] = v_src[i];
            }
        }
    }

    if (FC_flash_attn_ext_pad_has_mask) {
        if (i2 < args.ne32 && i3 < args.ne33) {
            for (int ib = i1; ib < args.ne31; ib += C) {
                device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0;
                device       half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3;

                for (int i = tiitg; i < C; i += ntg.x) {
                    if (i >= icp) {
                        mask_dst[i] = -MAXHALF;
                    } else {
                        mask_dst[i] = mask_src[i];
                    }
                }
            }
        }
    }
}

constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]];
constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]];

// scan the blocks of the mask that are not masked
// 0 - masked (i.e. full of -INF, skip)
// 1 - not masked (i.e. at least one element of the mask is not -INF)
kernel void kernel_flash_attn_ext_blk(
        constant ggml_metal_kargs_flash_attn_ext_blk & args,
        device const char * mask,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]]) {
    // block size C x Q
    const int32_t Q = FC_flash_attn_ext_blk_nqptg;
    const int32_t C = FC_flash_attn_ext_blk_ncpsg;

    constexpr short NW = N_SIMDWIDTH;

    const int32_t i3 = tgpig[2]/args.ne32;
    const int32_t i2 = tgpig[2]%args.ne32;
    const int32_t i1 = tgpig[1];
    const int32_t i0 = tgpig[0];

    char res = i0*C + C > args.ne30 ? 1 : 0;

    device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;

    // fast route
    if (res == 0) {
        if (simd_max(*mask_src) > -MAXHALF/2) {
            res = 1;
        }
    }

    // detailed check of the elements of the block
    if ((C > NW || Q > 1) && res == 0) {
        half m = -MAXHALF;

        FOR_UNROLL (short j = 0; j < Q; ++j) {
            FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
                m = max(m, mask_src[ii*NW]);
            }

            mask_src += args.nb31/2;
        }

        if (simd_max(m) > -MAXHALF/2) {
            res = 1;
        }
    }

    const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
    const int32_t nblk0 = ((args.ne30 + C - 1)/C);

    if (tiisg == 0) {
        dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res;
    }
}
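
// note: each simdgroup reduces one C x Q block of the mask to a single byte: the fast path inspects a
//       single NW-wide row via simd_max(), and the detailed loop is only needed when C > NW or Q > 1.
//       blocks that overhang ne30 are conservatively marked as not masked (res = 1)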
  4036. constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
  4037. constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
  4038. constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
  4039. constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
  4040. constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];
  4041. constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
  4042. //constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
  4043. //constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
  4044. //constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]];
  4045. constant int32_t FC_flash_attn_ext_ns10 [[function_constant(FC_FLASH_ATTN_EXT + 20)]];
  4046. constant int32_t FC_flash_attn_ext_ns20 [[function_constant(FC_FLASH_ATTN_EXT + 21)]];
  4047. constant int32_t FC_flash_attn_ext_nsg [[function_constant(FC_FLASH_ATTN_EXT + 22)]];
  4048. // ref: https://arxiv.org/pdf/2307.08691.pdf
  4049. template<
  4050. typename q_t, // query types in shared memory
  4051. typename q4_t,
  4052. typename q8x8_t,
  4053. typename k_t, // key types in shared memory
  4054. typename k4x4_t,
  4055. typename k8x8_t,
  4056. typename v_t, // value types in shared memory
  4057. typename v4x4_t,
  4058. typename v8x8_t,
  4059. typename qk_t, // Q*K types
  4060. typename qk8x8_t,
  4061. typename s_t, // soft-max types
  4062. typename s2_t,
  4063. typename s8x8_t,
  4064. typename o_t, // attention accumulation types
  4065. typename o4_t,
  4066. typename o8x8_t,
  4067. typename kd4x4_t, // key type in device memory
  4068. short nl_k,
  4069. void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
  4070. typename vd4x4_t, // value type in device memory
  4071. short nl_v,
  4072. void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
  4073. short DK, // K head size
  4074. short DV, // V head size
  4075. short Q, // queries per threadgroup
  4076. short C, // cache items per threadgroup
  4077. short NSG> // number of simd groups
  4078. void kernel_flash_attn_ext_impl(
  4079. constant ggml_metal_kargs_flash_attn_ext & args,
  4080. device const char * q,
  4081. device const char * k,
  4082. device const char * v,
  4083. device const char * mask,
  4084. device const char * sinks,
  4085. device const char * pad,
  4086. device const char * blk,
  4087. device char * dst,
  4088. threadgroup half * shmem_f16,
  4089. uint3 tgpig,
  4090. ushort tiisg,
  4091. ushort sgitg) {
  4092. const ushort iq3 = tgpig[2];
  4093. const ushort iq2 = tgpig[1];
  4094. const ushort iq1 = tgpig[0]*Q;
  4095. #define NS10 (FC_flash_attn_ext_ns10)
  4096. #define NS20 (FC_flash_attn_ext_ns20)
  4097. // note: I had some concerns that using this instead of the ugly macros above was affecting performance
  4098. // need to re-check carefully and if no regressions are observerd - remove the macros
  4099. // the concerns is that maybe using const variables requires extra registers? but not sure if the compiler
  4100. // is clever enough to avoid this. unfortunately, using constexpr is not possible with FC
  4101. //const short NS10 = FC_flash_attn_ext_ns10;
  4102. //const short NS20 = FC_flash_attn_ext_ns20;
    constexpr short KV = 8;

    constexpr short DK4  = DK/4;
    constexpr short DK8  = DK/8;
    constexpr short DK16 = DK/16;
    constexpr short DV4  = DV/4;
  //constexpr short DV8  = DV/8;
    constexpr short DV16 = DV/16;

    constexpr short PV  = PAD2(DV, 64);
    constexpr short PV4 = PV/4;
    constexpr short PV8 = PV/8;
  //constexpr short PV16 = PV/16;

    constexpr short NW = N_SIMDWIDTH;
    constexpr short NQ = Q/NSG;
    constexpr short SH = 2*C;        // shared memory per simdgroup (s_t == float)
    constexpr short TS = 2*SH;
    constexpr short T  = DK + 2*PV;  // shared memory size per query in (half)

    threadgroup q_t    * sq    = (threadgroup q_t    *) (shmem_f16 + 0*T);        // holds the query data
    threadgroup q4_t   * sq4   = (threadgroup q4_t   *) (shmem_f16 + 0*T);        // same as above but in q4_t
    threadgroup o_t    * so    = (threadgroup o_t    *) (shmem_f16 + 0*T + Q*DK); // the result for all queries in 8x8 matrices (the O matrix from the paper)
    threadgroup o4_t   * so4   = (threadgroup o4_t   *) (shmem_f16 + 0*T + Q*DK);
    threadgroup s_t    * ss    = (threadgroup s_t    *) (shmem_f16 + Q*T);        // scratch buffer for attention, mask and diagonal matrix
    threadgroup s2_t   * ss2   = (threadgroup s2_t   *) (shmem_f16 + Q*T);        // same as above but in s2_t
    threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load K in shared memory
    threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in k4x4_t
    threadgroup v_t    * sv    = (threadgroup v_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load V in shared memory
    threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in v4x4_t

    // mask storage in shared mem
    threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C);
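
    // rough layout of shmem_f16 (in half units), derived from the pointer setup above:
    //
    //   [0,          Q*DK      )  sq     - Q rows of query data, DK halves per row
    //   [Q*DK,       Q*T       )  so     - Q rows of output accumulators, PV elements of o_t per row
    //   [Q*T,        Q*T + Q*TS)  ss/sm2 - Q scratch rows holding the attention scores and the mask
    //   [Q*T + Q*TS, ...       )  sk/sv  - per-simdgroup staging buffers for dequantized K/V tiles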

    // per-query mask pointers
    device const half2 * pm2[NQ];

    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
        const short j = jj*NSG + sgitg;

        pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
    }

    {
        const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
        const int32_t nblk0 = ((args.ne11 + C - 1)/C);

        blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0;
    }
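
    // note: blk now points at a row of nblk0 per-block flags for this (batch, head, query block);
    //       blk[ic0] == 0 marks a (Q x C) block of the mask that is fully masked out, so the main
    //       loop below can skip it without ever touching the corresponding K/V data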
    {
        q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;

        const short ikv2 = iq2/(args.ne02/args.ne_12_2);
        const short ikv3 = iq3/(args.ne03/args.ne_12_3);

        k += ikv2*args.nb12 + ikv3*args.nb13;
        v += ikv2*args.nb22 + ikv3*args.nb23;
    }

    // load heads from Q to shared memory
    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
        const short j = jj*NSG + sgitg;

        device const float4 * q4 = (device const float4 *) ((device const char *) q + j*args.nb01);

        for (short i = tiisg; i < DK4; i += NW) {
            if (iq1 + j < args.ne01) {
                sq4[j*DK4 + i] = (q4_t) q4[i];
            } else {
                sq4[j*DK4 + i] = 0;
            }
        }
    }

    // zero out
    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
        const short j = jj*NSG + sgitg;

        for (short i = tiisg; i < DV4; i += NW) {
            so4[j*PV4 + i] = 0;
        }

        for (short i = tiisg; i < SH; i += NW) {
            ss[j*SH + i] = 0.0f;
        }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    float S[NQ] = { [0 ... NQ-1] = 0.0f };

    {
        float M[NQ] = { [0 ... NQ-1] = -FLT_MAX/2 };

        float slope = 1.0f;

        // ALiBi
        if (FC_flash_attn_ext_has_bias) {
            const short h = iq2;

            const float base = h < args.n_head_log2 ? args.m0 : args.m1;
            const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;

            slope = pow(base, exph);
        }
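
        // i.e. the standard ALiBi slopes (https://arxiv.org/abs/2108.12409):
        //
        //   slope = m0^(h + 1)                   , h <  n_head_log2
        //   slope = m1^(2*(h - n_head_log2) + 1) , h >= n_head_log2
        //
        // e.g. with 8 heads and m0 = 1/2 this yields the slopes 1/2, 1/4, ..., 1/256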
        // loop over the KV cache
        // each simdgroup handles blocks of Q rows and C columns
        for (int ic0 = 0; ; ++ic0) {
            int ic = ic0*C;
            if (ic >= args.ne11) {
                break;
            }

            // the last partial chunk uses the pad buffer as source
            if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) {
                k = pad;
                v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;

                mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;

                const short ikv2 = iq2/(args.ne02/args.ne_12_2);
                const short ikv3 = iq3/(args.ne03/args.ne_12_3);

                k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
                v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;

                if (!FC_flash_attn_ext_has_mask) {
                    threadgroup half * sm = (threadgroup half *) (sm2);

                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                        const short j = jj*NSG + sgitg;

                        for (short i = tiisg; i < C; i += NW) {
                            if (ic + i >= args.ne11) {
                                sm[2*j*SH + i] = -MAXHALF;
                            }
                        }
                    }
                } else {
                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                        const short j = jj*NSG + sgitg;

                        pm2[jj] = (device const half2 *) ((device const half *) mask +
                                (iq1 + j)*C +
                                (iq2%args.ne32)*(C*args.ne31) +
                                (iq3%args.ne33)*(C*args.ne31*args.ne32));
                    }
                }

                ic = 0;
            }

            // read the mask into shared mem
            if (FC_flash_attn_ext_has_mask) {
                if (blk[ic0] == 0) {
                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                        pm2[jj] += NW;
                    }

                    continue;
                }

                FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                    const short j = jj*NSG + sgitg;

                    if (FC_flash_attn_ext_bc_mask) {
                        sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
                    } else {
                        sm2[j*SH + tiisg] = pm2[jj][tiisg];
                    }

                    pm2[jj] += NW;
                }

#if 0
                // note: old -INF block optimization - obsoleted by pre-computing the non-masked blocks
                threadgroup_barrier(mem_flags::mem_threadgroup);

                // used to detect blocks full of -INF
                // skip only when the entire threadgroup is masked
                half2 smax2(-MAXHALF/2, -MAXHALF/2);

                FOR_UNROLL (short j = 0; j < Q; ++j) {
                    smax2 = max(smax2, sm2[j*SH + tiisg]);
                }

                smax2 = simd_max(smax2);

                if (max(smax2[0], smax2[1]) <= -MAXHALF/2) {
                    // this barrier is important
                    threadgroup_barrier(mem_flags::mem_threadgroup);

                    continue;
                }
#endif
            }

            // Q*K^T
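            // note: the simdgroup_load() calls for K below pass transpose = true, so each 8x8
            //       multiply-accumulate produces a tile of Q*K^T without materializing K^T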
            // this is a compile-time check, so it has no runtime overhead
            if (is_same<kd4x4_t, k4x4_t>::value) {
                // we can read directly from global memory
                device const k_t * pk = (device const k_t *) (k + ic*args.nb11);

                threadgroup const q_t * pq = sq;
                threadgroup       s_t * ps = ss;

                pk += sgitg*(8*NS10);
                ps += sgitg*(8*1);

                static_assert((C/8) % NSG == 0, "");

                constexpr short NC = (C/8)/NSG;

                // note: do not unroll for large heads
                #pragma unroll (DK <= 64 ? NC : 1)
                for (short cc = 0; cc < NC; ++cc) {
                    qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);

                    if (DK % 16 != 0) {
                        k8x8_t mk;
                        q8x8_t mq;

                        FOR_UNROLL (short i = 0; i < DK8; ++i) {
                            simdgroup_barrier(mem_flags::mem_none);

                            simdgroup_load(mk, pk + 8*i, NS10, 0, true);
                            simdgroup_load(mq, pq + 8*i, DK);

                            simdgroup_barrier(mem_flags::mem_none);

                            simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
                        }
                    } else {
                        k8x8_t mk[2];
                        q8x8_t mq[2];

                        FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
                            simdgroup_barrier(mem_flags::mem_none);

                            simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
                            simdgroup_load(mq[1], pq + 1*8 + 16*i, DK);

                            simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true);
                            simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true);

                            simdgroup_barrier(mem_flags::mem_none);

                            simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);
                            simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);
                        }
                    }

                    simdgroup_store(mqk, ps, SH, 0, false);

                    pk += 8*(NSG*NS10);
                    ps += 8*(NSG);
                }
            } else {
                // TODO: this is the quantized K cache branch - not optimized yet
                for (short ccc = 0; ccc < (C/8)/NSG; ++ccc) {
                    const short cc = ccc*NSG + sgitg;

                    const short tx = tiisg%4;
                    const short ty = tiisg/4;

                    qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);

                    for (short ii = 0; ii < DK16; ii += 4) {
                        device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11));

                        if (DK16%4 == 0) {
                            // the head is evenly divisible by 4*16 = 64, so no need for bound checks
                            {
                                k4x4_t tmp;
                                deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
                                sk4x4[4*ty + tx] = tmp;
                            }

                            simdgroup_barrier(mem_flags::mem_threadgroup);

                            FOR_UNROLL (short k = 0; k < 4; ++k) {
                                k8x8_t mk;
                                q8x8_t mq;

                                simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
                                simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);

                                simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
                                simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
                            }
                        } else {
                            if (ii + tx < DK16) {
                                k4x4_t tmp;
                                deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
                                sk4x4[4*ty + tx] = tmp;
                            }

                            simdgroup_barrier(mem_flags::mem_threadgroup);

                            for (short k = 0; k < 4 && ii + k < DK16; ++k) {
                                k8x8_t mk;
                                q8x8_t mq;

                                simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
                                simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);

                                simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
                                simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
                                simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
                            }
                        }
                    }

                    simdgroup_store(mqk, ss + 8*cc, SH, 0, false);
                }
            }

            threadgroup_barrier(mem_flags::mem_threadgroup);
            // online softmax
            FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                const short j = jj*NSG + sgitg;

                const float m = M[jj];

                // scale and apply the logit softcap / mask
                float2 s2 = ss2[j*SH/2 + tiisg]*args.scale;

                if (FC_flash_attn_ext_has_scap) {
                    s2 = args.logit_softcap*precise::tanh(s2);
                }

                // mqk = mqk + slope*mask
                if (FC_flash_attn_ext_has_bias) {
                    s2 += s2_t(sm2[j*SH + tiisg])*slope;
                } else {
                    s2 += s2_t(sm2[j*SH + tiisg]);
                }

                M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));

                const float ms = exp(m - M[jj]);

                const float2 vs2 = exp(s2 - M[jj]);

                S[jj] = S[jj]*ms + simd_sum(vs2[0] + vs2[1]);

                // the P matrix from the paper (Q rows, C columns)
                ss2[j*SH/2 + tiisg] = vs2;

                // O = diag(ms)*O
                if (DV4 % NW == 0) {
                    FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
                        const short i = ii*NW + tiisg;

                        so4[j*PV4 + i] *= ms;
                    }
                } else {
                    for (short i = tiisg; i < DV4; i += NW) {
                        so4[j*PV4 + i] *= ms;
                    }
                }
            }
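
            // e.g. if the previous maximum was 1.0 and this block contains a score of 3.0, then
            // ms = exp(1.0 - 3.0) = exp(-2) and both S and the O accumulator shrink by that factor
            // before the new block's probabilities are accumulated - exactly the recurrence from
            // the note at the top of this kernel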
            threadgroup_barrier(mem_flags::mem_threadgroup);

            // O = O + (Q*K^T)*V
            {
                // we can read directly from global memory
                if (is_same<vd4x4_t, v4x4_t>::value) {
                    static_assert(PV8 % NSG == 0, "");

                    constexpr short NO = PV8/NSG;

                    o8x8_t lo[NO];

                    {
                        auto sot = so + 8*sgitg;

                        FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
                            simdgroup_load(lo[ii], sot, PV, 0, false);

                            sot += 8*NSG;
                        }
                    }

                    {
                        device const v_t * pv = (device const v_t *) (v + ic*args.nb21);

                        pv += 8*sgitg;

                        if (DV <= 64) {
                            FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
                                s8x8_t vs;
                                simdgroup_load(vs, ss + 8*cc, SH, 0, false);

                                FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
                                    v8x8_t mv[2];

                                    simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false);
                                    simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false);

                                    simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]);
                                    simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]);
                                }

                                pv += 8*NS20;
                            }
                        } else {
                            FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
                                s8x8_t vs[2];

                                simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
                                simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false);

                                FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
                                    v8x8_t mv[4];

                                    simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
                                    simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
                                    simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
                                    simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);

                                    simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]);
                                    simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]);
                                    simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]);
                                    simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]);
                                }

                                pv += 2*8*NS20;
                            }
                        }
                    }

                    {
                        auto sot = so + 8*sgitg;

                        FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
                            simdgroup_store(lo[ii], sot, PV, 0, false);

                            sot += 8*NSG;
                        }
                    }
                } else {
                    // TODO: this is the quantized V cache branch - not optimized yet
                    const short tx = tiisg%4;
                    const short ty = tiisg/4;

                    for (short cc = 0; cc < C/8; ++cc) {
                        s8x8_t vs;
                        simdgroup_load(vs, ss + 8*cc, SH, 0, false);

                        for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21));

                            if (DV16%4 == 0) {
                                // no need for bound checks
                                {
                                    v4x4_t tmp;
                                    deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
                                    sv4x4[4*ty + tx] = tmp;
                                }

                                simdgroup_barrier(mem_flags::mem_threadgroup);

                                FOR_UNROLL (short k = 0; k < 4; ++k) {
                                    v8x8_t mv[2];
                                    o8x8_t lo[2];

                                    simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);
                                    simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);

                                    simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
                                    simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);

                                    simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);
                                    simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);

                                    simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
                                    simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
                                }
                            } else {
                                if (ii + tx < DV16) {
                                    v4x4_t tmp;
                                    deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
                                    sv4x4[4*ty + tx] = tmp;
                                }

                                simdgroup_barrier(mem_flags::mem_threadgroup);

                                for (short k = 0; k < 4 && ii + k < DV16; ++k) {
                                    v8x8_t mv[2];
                                    o8x8_t lo[2];

                                    simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);
                                    simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);

                                    simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
                                    simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);

                                    simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);
                                    simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);

                                    simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
                                    simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
                                }
                            }
                        }
                    }
                }
            }

            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
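
        // attention sinks: one extra per-head logit that joins the softmax normalization but
        // contributes no value - fold it into the running (M, S) and rescale O accordingly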
        if (FC_flash_attn_ext_has_sinks) {
            FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                const short j = jj*NSG + sgitg;

                const float m = M[jj];
                const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;

                M[jj] = simd_max(max(M[jj], s));

                const float ms = exp(m - M[jj]);
                const float vs = exp(s - M[jj]);

                S[jj] = S[jj]*ms + simd_sum(vs);

                for (short i = tiisg; i < DV4; i += NW) {
                    so4[j*PV4 + i] *= ms;
                }
            }
        }
    }

    // final rescale with 1/S and store to global memory
    for (short jj = 0; jj < NQ; ++jj) {
        const short j = jj*NSG + sgitg;

        if (iq1 + j >= args.ne01) {
            break;
        }

        device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;

        const float scale = S[jj] == 0.0f ? 0.0f : 1.0f/S[jj];

        if (DV4 % NW == 0) {
            FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
                const short i = ii*NW + tiisg;

                dst4[i] = (float4) so4[j*PV4 + i]*scale;
            }
        } else {
            for (short i = tiisg; i < DV4; i += NW) {
                dst4[i] = (float4) so4[j*PV4 + i]*scale;
            }
        }
    }

#undef NS10
#undef NS20
}
template<
    typename q_t,     // query types in shared memory
    typename q4_t,
    typename q8x8_t,
    typename k_t,     // key types in shared memory
    typename k4x4_t,
    typename k8x8_t,
    typename v_t,     // value types in shared memory
    typename v4x4_t,
    typename v8x8_t,
    typename qk_t,    // Q*K types
    typename qk8x8_t,
    typename s_t,     // soft-max types
    typename s2_t,
    typename s8x8_t,
    typename o_t,     // attention accumulation types
    typename o4_t,
    typename o8x8_t,
    typename kd4x4_t, // key type in device memory
    short nl_k,
    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
    typename vd4x4_t, // value type in device memory
    short nl_v,
    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
    short DK,                          // K head size
    short DV,                          // V head size
    short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
    short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
kernel void kernel_flash_attn_ext(
        constant ggml_metal_kargs_flash_attn_ext & args,
        device const char * q,
        device const char * k,
        device const char * v,
        device const char * mask,
        device const char * sinks,
        device const char * pad,
        device const char * blk,
        device       char * dst,
        threadgroup  half * shmem_f16 [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg
    switch (FC_flash_attn_ext_nsg) {
        // note: disabled cases to reduce library load time
      //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
      //case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
        case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
    }
#undef FWD_TMPL
#undef FWD_ARGS
}
// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now
//       keep them as template parameters to be able to explore different combinations
//
#define FA_TYPES \
    half,  half4,    simdgroup_half8x8,  \
    half,  half4x4,  simdgroup_half8x8,  \
    half,  half4x4,  simdgroup_half8x8,  \
    float,           simdgroup_float8x8, \
    float, float2,   simdgroup_float8x8, \
    float, float4,   simdgroup_float8x8
  //half,  half4,    simdgroup_half8x8

#define FA_TYPES_BF \
    bfloat, bfloat4,   simdgroup_bfloat8x8, \
    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
    float,             simdgroup_float8x8,  \
    float,  float2,    simdgroup_float8x8,  \
    half,   half4,     simdgroup_half8x8
  //float,  float4,    simdgroup_float8x8

#define FA_TYPES_F32 \
    half,  half4,    simdgroup_half8x8,  \
    float, float4x4, simdgroup_float8x8, \
    float, float4x4, simdgroup_float8x8, \
    float,           simdgroup_float8x8, \
    float, float2,   simdgroup_float8x8, \
    float, float4,   simdgroup_float8x8
  //half,  half4,    simdgroup_half8x8

typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
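
// note: the host side looks these entry points up by their [[host_name]] string
//       (e.g. "kernel_flash_attn_ext_f16_dk128_dv128") and supplies the FC_flash_attn_ext_*
//       function constants declared above when the compute pipeline is created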
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32,  32>;
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40,  40>;
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64,  64>;
template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 72,  72>;
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80,  80>;
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96,  96>;
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;

template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32,  32>;
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40,  40>;
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64,  64>;
template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 72,  72>;
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80,  80>;
template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96,  96>;
template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128, 128>;
template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;

#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32,  32>;
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40,  40>;
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64,  64>;
template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 72,  72>;
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80,  80>;
template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96,  96>;
template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
#endif

template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32,  32>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40,  40>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64,  64>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72,  72>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80,  80>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96,  96>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;

template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32,  32>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40,  40>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64,  64>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72,  72>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80,  80>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96,  96>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;

template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32,  32>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40,  40>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64,  64>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72,  72>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80,  80>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96,  96>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;

template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32,  32>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40,  40>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64,  64>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72,  72>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80,  80>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96,  96>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;

template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32,  32>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40,  40>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64,  64>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72,  72>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80,  80>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96,  96>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;

#undef FA_TYPES
#undef FA_TYPES_BF
#undef FA_TYPES_F32
constant bool FC_flash_attn_ext_vec_has_mask  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]];
constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
constant bool FC_flash_attn_ext_vec_has_bias  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];
constant bool FC_flash_attn_ext_vec_has_scap  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];
constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]];

//constant float FC_flash_attn_ext_vec_scale         [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];
//constant float FC_flash_attn_ext_vec_max_bias      [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];
//constant float FC_flash_attn_ext_vec_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 12)]];

constant int32_t FC_flash_attn_ext_vec_ns10 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 20)]];
constant int32_t FC_flash_attn_ext_vec_ns20 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 21)]];
constant int32_t FC_flash_attn_ext_vec_nsg  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 22)]];
constant int32_t FC_flash_attn_ext_vec_nwg  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 23)]];

template<
    typename q4_t,  // query types in shared memory
    typename k4_t,  // key types in shared memory
    typename v4_t,  // value types in shared memory
    typename qk_t,  // Q*K types
    typename s_t,   // soft-max types
    typename s4_t,
    typename o4_t,  // attention accumulation types
    typename kd4_t, // key type in device memory
    short nl_k,
    void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
    typename vd4_t, // value type in device memory
    short nl_v,
    void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
    short DK,  // K head size
    short DV,  // V head size
    short NE,  // head elements per thread
    short Q,   // queries per threadgroup
    short C,   // cache items per threadgroup
    short NSG> // number of simd groups
void kernel_flash_attn_ext_vec_impl(
        constant ggml_metal_kargs_flash_attn_ext_vec & args,
        device const char * q,
        device const char * k,
        device const char * v,
        device const char * mask,
        device const char * sinks,
        device const char * pad,
        device       char * dst,
        threadgroup  half * shmem_f16 [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    static_assert(DK % 32 == 0, "DK must be divisible by 32");
    static_assert(DV % 32 == 0, "DV must be divisible by 32");

#define NWG  (FC_flash_attn_ext_vec_nwg)
#define NS10 (FC_flash_attn_ext_vec_ns10)
#define NS20 (FC_flash_attn_ext_vec_ns20)

    const short iwg = tgpig[2]%NWG;

    const ushort iq3 = tgpig[2]/NWG;
    const ushort iq2 = tgpig[1];
    const ushort iq1 = tgpig[0];

    constexpr short DK4 = DK/4;
    constexpr short DV4 = DV/4;

    constexpr short PK  = PAD2(DK, 128);
    constexpr short PK4 = PK/4;

    constexpr short PV  = PAD2(DV, 128);
    constexpr short PV4 = PV/4;

    constexpr short NW = N_SIMDWIDTH;
    constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup workloads
    constexpr short SH = 4*C;   // shared memory per simdgroup

    static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
    static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
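
    // note: the NW threads of a simdgroup are viewed as an NE x NL grid:
    //       tx = tiisg%NL selects a slice of the head dimension and ty = tiisg/NL selects one of
    //       the NE cache elements processed concurrently, i.e. each group of NL lanes cooperates
    //       on a single KV position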
    const short T = PK + NSG*SH; // shared memory size per query in (half)

  //threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 + 0*PK);                  // holds the query data
    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK);                  // same as above but in q4_t
    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + sgitg*SH + Q*PK);       // scratch buffer for attention
    threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK);       // same as above but in s4_t
    threadgroup half * sm  = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask
    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T);      // scratch buffer for the results

    // store the result for all queries in shared memory (the O matrix from the paper)
    so4 += tiisg;

    {
        q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;

        const short ikv2 = iq2/(args.ne02/args.ne_12_2);
        const short ikv3 = iq3/(args.ne03/args.ne_12_3);

        k += ikv2*args.nb12 + ikv3*args.nb13;
        v += ikv2*args.nb22 + ikv3*args.nb23;
    }

    // load heads from Q to shared memory
    device const float4 * q4 = (device const float4 *) ((device const char *) q);

    for (short i = tiisg; i < PK4; i += NW) {
        if (iq1 < args.ne01 && i < DK4) {
            sq4[i] = (q4_t) q4[i];
        } else {
            sq4[i] = (q4_t) 0.0f;
        }
    }

    // zero out so
    for (short i = 0; i < DV4/NL; ++i) {
        so4[i*NL] = (o4_t) 0.0f;
    }

    // zero out shared memory SH
    for (short i = tiisg; i < SH/4; i += NW) {
        ss4[i] = (s4_t) 0.0f;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    {
        float S = 0.0f;
        float M = -FLT_MAX/2;

        // thread indices inside the simdgroup
        const short tx = tiisg%NL;
        const short ty = tiisg/NL;

        // pointer to the mask
        device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);

        float slope = 1.0f;

        // ALiBi
        if (FC_flash_attn_ext_vec_has_bias) {
            const short h = iq2;

            const float base = h < args.n_head_log2 ? args.m0 : args.m1;
            const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;

            slope = pow(base, exph);
        }

        // loop over the KV cache
        // each simdgroup handles blocks of Q rows and C columns
        for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) {
            int ic = ic0*C;
            if (ic >= args.ne11) {
                break;
            }

            // the last partial chunk uses the pad buffer as source
            if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
                k = pad;
                v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;

                mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;

                const short ikv2 = iq2/(args.ne02/args.ne_12_2);
                const short ikv3 = iq3/(args.ne03/args.ne_12_3);

                k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
                v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;

                if (!FC_flash_attn_ext_vec_has_mask) {
                    if (ic + tiisg >= args.ne11) {
                        sm[tiisg] = -MAXHALF;
                    }
                } else {
                    pm = (device const half *) (mask) +
                            iq1*C +
                            (iq2%args.ne32)*(C*args.ne31) +
                            (iq3%args.ne33)*(C*args.ne31*args.ne32);
                }

                ic = 0;
            }

            if (FC_flash_attn_ext_vec_has_mask) {
                sm[tiisg] = pm[ic + tiisg];
            }

            // skip -INF blocks
            if (simd_max(sm[tiisg]) == -INFINITY) {
                continue;
            }

            // Q*K^T
            {
                device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11);

                threadgroup const q4_t * pq4 = sq4;

                pk4 += ty*NS10/4 + tx;
                pq4 += tx;

                qk_t mqk[C/NE] = { [0 ... C/NE - 1] = 0.0f };

                // each simdgroup processes 1 query and NE (NW/NL) cache elements
                FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
                    if (is_same<kd4_t, k4_t>::value) {
                        FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) {
                            mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]);
                        }
                    } else {
                        device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11));

                        k4_t mk;

                        FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) {
                            const short i = ii*NL + tx;

                            deq_k_t4(pk + i/nl_k, i%nl_k, mk);

                            mqk[cc] += dot((float4) mk, (float4) sq4[i]);
                        }
                    }

                    if (NE == 1) {
                        mqk[cc] = simd_sum(mqk[cc]);
                    } else {
                        // simdgroup reduce (NE = 4)
                        // [ 0 ..  7] -> [ 0]
                        // [ 8 .. 15] -> [ 8]
                        // [16 .. 23] -> [16]
                        // [24 .. 31] -> [24]
                        if (NE <= 1) {
                            mqk[cc] += simd_shuffle_down(mqk[cc], 16);
                        }
                        if (NE <= 2) {
                            mqk[cc] += simd_shuffle_down(mqk[cc],  8);
                        }
                        if (NE <= 4) {
                            mqk[cc] += simd_shuffle_down(mqk[cc],  4);
                        }
                        if (NE <= 8) {
                            mqk[cc] += simd_shuffle_down(mqk[cc],  2);
                        }
                        if (NE <= 16) {
                            mqk[cc] += simd_shuffle_down(mqk[cc],  1);
                        }

                        // broadcast
                        mqk[cc] = simd_shuffle(mqk[cc], NL*ty);
                    }
                }

                if (FC_flash_attn_ext_vec_has_mask &&
                   !FC_flash_attn_ext_vec_has_scap &&
                   !FC_flash_attn_ext_vec_has_bias) {
                    ss[NE*tx + ty] = fma(mqk[tx], args.scale, (qk_t) sm[NE*tx + ty]);
                } else {
                    mqk[tx] *= args.scale;

                    if (FC_flash_attn_ext_vec_has_scap) {
                        mqk[tx] = args.logit_softcap*precise::tanh(mqk[tx]);
                    }

                    if (FC_flash_attn_ext_vec_has_bias) {
                        mqk[tx] += (qk_t) sm[NE*tx + ty]*slope;
                    } else {
                        mqk[tx] += (qk_t) sm[NE*tx + ty];
                    }

                    ss[NE*tx + ty] = mqk[tx];
                }
            }

            simdgroup_barrier(mem_flags::mem_threadgroup);

            // online softmax
            {
                const float m = M;
                const float s = ss[tiisg];

                M = simd_max(max(M, s));

                const float ms = exp(m - M);
                const float vs = exp(s - M);

                S = S*ms + simd_sum(vs);

                // the P matrix from the paper (Q rows, C columns)
                ss[tiisg] = vs;

                // O = diag(ms)*O
                if ((DV4/NL % NW == 0) || ty == 0) {
                    FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                        so4[ii*NL] *= ms;
                    }
                }
            }

            simdgroup_barrier(mem_flags::mem_threadgroup);

            // O = O + (Q*K^T)*V
            {
                o4_t lo[DV4/NL];

                FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                    lo[ii] = 0.0f;
                }

                if (is_same<vd4_t, v4_t>::value) {
                    device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21);

                    pv4 += ty*NS20/4 + tx;

                    const auto sst = ss + ty;

                    FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
                        FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                            lo[ii] += o4_t(float4(pv4[cc*NE*NS20/4 + ii*NL])*float4(sst[cc*NE]));
                        }
                    }
                } else {
                    FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
                        device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21));

                        FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                            const short i = ii*NL + tx;

                            v4_t mv;

                            deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);

                            lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty]));
                        }
                    }
                }

                FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                    if (NE > 1) {
                        lo[ii][0] += simd_shuffle_down(lo[ii][0], 16);
                        lo[ii][1] += simd_shuffle_down(lo[ii][1], 16);
                        lo[ii][2] += simd_shuffle_down(lo[ii][2], 16);
                        lo[ii][3] += simd_shuffle_down(lo[ii][3], 16);
                    }

                    if (NE > 2) {
                        lo[ii][0] += simd_shuffle_down(lo[ii][0], 8);
                        lo[ii][1] += simd_shuffle_down(lo[ii][1], 8);
                        lo[ii][2] += simd_shuffle_down(lo[ii][2], 8);
                        lo[ii][3] += simd_shuffle_down(lo[ii][3], 8);
                    }

                    if (NE > 4) {
                        lo[ii][0] += simd_shuffle_down(lo[ii][0], 4);
                        lo[ii][1] += simd_shuffle_down(lo[ii][1], 4);
                        lo[ii][2] += simd_shuffle_down(lo[ii][2], 4);
                        lo[ii][3] += simd_shuffle_down(lo[ii][3], 4);
                    }

                    if (NE > 8) {
                        lo[ii][0] += simd_shuffle_down(lo[ii][0], 2);
                        lo[ii][1] += simd_shuffle_down(lo[ii][1], 2);
                        lo[ii][2] += simd_shuffle_down(lo[ii][2], 2);
                        lo[ii][3] += simd_shuffle_down(lo[ii][3], 2);
                    }

                    if (NE > 16) {
                        lo[ii][0] += simd_shuffle_down(lo[ii][0], 1);
                        lo[ii][1] += simd_shuffle_down(lo[ii][1], 1);
                        lo[ii][2] += simd_shuffle_down(lo[ii][2], 1);
                        lo[ii][3] += simd_shuffle_down(lo[ii][3], 1);
                    }
                }

                if ((DV4/NL % NW == 0) || ty == 0) {
                    FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                        so4[ii*NL] += lo[ii];
                    }
                }
            }
        }

        if (FC_flash_attn_ext_vec_has_sinks && sgitg == 0 && iwg == 0) {
            const float m = M;
            const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;

            M = simd_max(max(M, s));

            const float ms = exp(m - M);
            const float vs = exp(s - M);

            S = S*ms + simd_sum(vs);

            if ((DV4/NL % NW == 0) || ty == 0) {
                FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                    so4[ii*NL] *= ms;
                }
            }
        }

        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
        if (tiisg == 0) {
            ss[0] = (s_t) S;
            ss[1] = (s_t) M;
        }
    }

    so4 -= tiisg;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // parallel reduce
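    //
    // each step merges pairs of partial results (O_i, S_i, M_i) as:
    //
    //   M = max(M_0, M_1)
    //   S = S_0*exp(M_0 - M) + S_1*exp(M_1 - M)
    //   O = O_0*exp(M_0 - M) + O_1*exp(M_1 - M)
    //
    // so after log2(NSG) steps, simdgroup 0 holds the combined result
    //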
    for (short r = NSG/2; r > 0; r >>= 1) {
        if (sgitg < r) {
            const float S0 = ss[           0];
            const float S1 = ss[r*(SH/2) + 0];

            const float M0 = ss[           1];
            const float M1 = ss[r*(SH/2) + 1];

            const float M = max(M0, M1);

            const float ms0 = exp(M0 - M);
            const float ms1 = exp(M1 - M);

            const float S = S0*ms0 + S1*ms1;

            if (tiisg == 0) {
                ss[0] = S;
                ss[1] = M;
            }

            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
            for (short i = tiisg; i < DV4; i += NW) {
                so4[i] = so4[i]*ms0 + so4[i + r*PV4]*ms1;
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // final rescale with 1/S and store to global memory
    if (sgitg == 0) {
        const int64_t nrows = args.ne3*args.ne2*args.ne1;
        const int64_t rid   = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;

        device float4 * dst4 = (device float4 *) dst;
        device float  * dst1 = (device float  *) dst + nrows*DV*NWG; // the S and M are stored after the results

        const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f;

        // interleave the workgroup data
        for (short i = tiisg; i < DV4; i += NW) {
            dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[i]*S;
        }

        // store S and M
        if (NWG > 1) {
            if (tiisg == 0) {
                dst1[rid*(2*NWG) + 2*iwg + 0] = ss[0];
                dst1[rid*(2*NWG) + 2*iwg + 1] = ss[1];
            }
        }
    }
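
    // note: when NWG > 1 the stored O is left unnormalized and the per-workgroup (S, M) pairs are
    //       appended after the results, presumably so that a separate reduction pass can combine
    //       the workgroups using the same merge rule as the parallel reduce above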
#undef NWG
#undef NS10
#undef NS20
}

template<
    typename q4_t,  // query types in shared memory
    typename k4_t,  // key types in shared memory
    typename v4_t,  // value types in shared memory
    typename qk_t,  // Q*K types
    typename s_t,   // soft-max types
    typename s4_t,
    typename o4_t,  // attention accumulation types
    typename kd4_t, // key type in device memory
    short nl_k,
    void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
    typename vd4_t, // value type in device memory
    short nl_v,
    void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
    short DK,                               // K head size
    short DV,                               // V head size
    short NE = 4,                           // head elements per thread
    short Q  = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup
    short C  = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
kernel void kernel_flash_attn_ext_vec(
        constant ggml_metal_kargs_flash_attn_ext_vec & args,
        device const char * q,
        device const char * k,
        device const char * v,
        device const char * mask,
        device const char * sinks,
        device const char * pad,
        device       char * dst,
        threadgroup  half * shmem_f16 [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
    switch (FC_flash_attn_ext_vec_nsg) {
        // note: disabled cases to reduce library load time
        case 1:  kernel_flash_attn_ext_vec_impl<FWD_TMPL,  1>(FWD_ARGS); break;
        case 2:  kernel_flash_attn_ext_vec_impl<FWD_TMPL,  2>(FWD_ARGS); break;
        case 4:  kernel_flash_attn_ext_vec_impl<FWD_TMPL,  4>(FWD_ARGS); break;
      //case 8:  kernel_flash_attn_ext_vec_impl<FWD_TMPL,  8>(FWD_ARGS); break;
      //case 16: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 16>(FWD_ARGS); break;
      //case 32: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 32>(FWD_ARGS); break;
    }
#undef FWD_TMPL
#undef FWD_ARGS
}
  5123. // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
  5124. // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
  5125. //
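// note: the FA_TYPES lists below map, in order, to the first template parameters of
//       kernel_flash_attn_ext_vec: q4_t, k4_t, v4_t, qk_t, s_t/s4_t and o4_t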
#define FA_TYPES \
    half4,  \
    half4,  \
    half4,  \
    float,  \
    float, float4, \
    float4

#define FA_TYPES_F32 \
    half4,  \
    float4, \
    float4, \
    float,  \
    float, float4, \
    float4

typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;

template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;

#undef FA_TYPES
#undef FA_TYPES_F32

constant int32_t FC_flash_attn_ext_vec_reduce_DV  [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]];
constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]];
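// note: combines the NWG partial results produced by kernel_flash_attn_ext_vec when it runs
//       with multiple workgroups per row: the per-workgroup (S, M) pairs stored after the
//       outputs are merged with simd_max/simd_sum (log-sum-exp), and the interleaved partial
//       outputs are rescaled and accumulated into the final row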
kernel void kernel_flash_attn_ext_vec_reduce(
        constant ggml_metal_kargs_flash_attn_ext_vec_reduce & args,
        device const char * htmp,
        device       char * dst,
        uint   tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
#define NWG (FC_flash_attn_ext_vec_reduce_NWG)
#define DV  (FC_flash_attn_ext_vec_reduce_DV)

    const uint64_t rid = tgpig;

    const short iwg = tiisg;

    device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*NWG;

    float S = ss[rid*(2*NWG) + 2*iwg + 0];
    float M = ss[rid*(2*NWG) + 2*iwg + 1];

    const float m  = simd_max(M);
    const float ms = exp(M - m);

    S = simd_sum(S*ms);
    S = S == 0.0f ? 0.0f : 1.0f/S;

    const short DV4 = DV/4;

    device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*NWG;
    device       float4 * dst4  = (device       float4 *) dst  + rid*DV4;

    for (short i = sgitg; i < DV4; i += NWG) {
        const float4 v = simd_sum(htmp4[i*NWG + iwg]*ms);

        if (iwg == 0) {
            dst4[i] = v*S;
        }
    }

#undef NWG
#undef DV
}
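// note: the cpy kernels support copies between tensors of different shape as long as the total
//       number of elements matches: the source element index n is linearized and then
//       decomposed into the destination coordinates i3/i2/i1/i0 below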
template<typename T0, typename T1>
kernel void kernel_cpy_t_t(
        constant ggml_metal_kargs_cpy & args,
        device const char * src0,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig[2];
    const int i02 = tgpig[1];
    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;

    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);

    device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
        device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
        dst_data[i00] = (T1) src[0];
        break;
    }
}

typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;

template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float,   float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float,   half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float,   int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
#endif
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_bf16_f32")]]  kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
#endif
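// note: quantizing copy - each thread packs one group of QK source floats into one block_q via
//       quantize_func; the loop bound args.nk0 counts destination blocks rather than elements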
template<short QK,
         typename block_q,
         void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_cpy_f32_q(
        constant ggml_metal_kargs_cpy & args,
        device const char * src0,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig[2];
    const int i02 = tgpig[1];
    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;

    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;

    device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
        device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);

        quantize_func(src, dst_data[i00]);
        break;
    }
}

typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;

template [[host_name("kernel_cpy_f32_q8_0")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0,  block_q8_0,   quantize_q8_0>;
template [[host_name("kernel_cpy_f32_q4_0")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0,  block_q4_0,   quantize_q4_0>;
template [[host_name("kernel_cpy_f32_q4_1")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1,  block_q4_1,   quantize_q4_1>;
template [[host_name("kernel_cpy_f32_q5_0")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0,  block_q5_0,   quantize_q5_0>;
template [[host_name("kernel_cpy_f32_q5_1")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1,  block_q5_1,   quantize_q5_1>;
template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_cpy_q_f32(
        constant ggml_metal_kargs_cpy & args,
        device const char * src0,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig[2];
    const int i02 = tgpig[1];
    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;

    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);

    device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
    device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
        T4x4 temp;
        dequantize_func(src_data + i00/nl, i00%nl, temp);
        dst_data[i00] = temp;
        break;
    }
}

typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;

template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;

template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
kernel void kernel_concat(
        constant ggml_metal_kargs_concat & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    int o[4] = {0, 0, 0, 0};
    o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));

    device const float * x;

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
            x = (device const float *)(src0 + (i3       )*args.nb03 + (i2       )*args.nb02 + (i1       )*args.nb01 + (i0       )*args.nb00);
        } else {
            x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
        }

        device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

        *y = *x;
    }
}
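// note: the quantized mul_mv kernels below share one layout: each simdgroup accumulates nr0
//       consecutive rows of src0 against a single column of src1, with the 32 threads of the
//       simdgroup typically split across superblocks in flight (ix) and positions within a
//       superblock (it/iq/ir); the per-row partial sums are combined at the end with simd_sum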
template<int nr0, typename args_t>
void kernel_mul_mv_q2_K_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
    device const float      * y = (device const float      *) (src1 + offset1);

    float yl[32];
    float sumf[nr0]={0.f};

    const short ix = tiisg/8;  // 0...3
    const short it = tiisg%8;  // 0...7
    const short iq = it/4;     // 0 or 1
    const short ir = it%4;     // 0...3
    const short is = (8*ir)/16;// 0 or 1

    device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;

    for (int ib = ix; ib < nb; ib += 4) {
        float4 sumy = {0.f, 0.f, 0.f, 0.f};
        for (short i = 0; i < 8; ++i) {
            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
            yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
            yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
            yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
        }

        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*iq + is;
        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
        device const half     * dh = &x[ib].d;

        for (short row = 0; row < nr0; row++) {
            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
            float4 acc2 = {0.f, 0.f, 0.f, 0.f};

            for (int i = 0; i < 8; i += 2) {
                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
            }

            float dall = dh[0];
            float dmin = dh[1] * 1.f/16.f;

            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
                         dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));

            qs += args.nb01/2;
            sc += args.nb01;
            dh += args.nb01/2;
        }

        y4 += 4 * QK_K;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}

[[host_name("kernel_mul_mv_q2_K_f32")]]
kernel void kernel_mul_mv_q2_K_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, typename args_t>
void kernel_mul_mv_q3_K_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_q3_K * x  = (device const block_q3_K *) (src0 + offset0);
    device const float      * yy = (device const float      *) (src1 + offset1);

    float yl[32];

    //const uint16_t kmask1 = 0x3030;
    //const uint16_t kmask2 = 0x0f0f;

    const short tid = tiisg/4;
    const short ix  = tiisg%4;
    const short ip  = tid/4;         // 0 or 1
    const short il  = 2*((tid%4)/2); // 0 or 2
    const short ir  = tid%2;
    const short l0  = 8*ir;
    // One would think that the Metal compiler would figure out that ip and il can only have
    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
    // with these two tables.
    //
    // Possible masks for the high bit
    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2

    // Possible masks for the low 2 bits
    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};

    const ushort4 hm = mm[2*ip + il/2];

    const short shift = 2*il;

    const float v1 = il == 0 ? 4.f : 64.f;
    const float v2 = 4.f * v1;

    const uint16_t s_shift1 = 4*ip;
    const uint16_t s_shift2 = s_shift1 + il;

    const short q_offset = 32*ip + l0;
    const short y_offset = 128*ip + 32*il + l0;

    device const float * y1 = yy + ix*QK_K + y_offset;

    uint32_t scales32, aux32;
    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
    thread const int8_t * scales = (thread const int8_t *)&scales32;

    float sumf1[nr0] = {0.f};
    float sumf2[nr0] = {0.f};

    for (int i = ix; i < nb; i += 4) {
        for (short l = 0; l < 8; ++l) {
            yl[l+ 0] = y1[l+ 0];
            yl[l+ 8] = y1[l+16];
            yl[l+16] = y1[l+32];
            yl[l+24] = y1[l+48];
        }

        device const uint16_t * q  = (device const uint16_t *)(x[i].qs + q_offset);
        device const uint16_t * h  = (device const uint16_t *)(x[i].hmask + l0);
        device const uint16_t * a  = (device const uint16_t *)(x[i].scales);
        device const half     * dh = &x[i].d;

        for (short row = 0; row < nr0; ++row) {
            const float d_all = (float)dh[0];

            scales16[0] = a[4];
            scales16[1] = a[5];
            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
            scales16[0] = a[il+0];
            scales16[1] = a[il+1];
            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;

            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
            for (short l = 0; l < 8; l += 2) {
                const int32_t qs = q[l/2];
                s1 += yl[l+0] * (qs & qm[il/2][0]);
                s2 += yl[l+1] * (qs & qm[il/2][1]);
                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
                s4 += yl[l+16] * (qs & qm[il/2][2]);
                s5 += yl[l+17] * (qs & qm[il/2][3]);
                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
            }

            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);

            sumf1[row] += d1 * (scales[0] - 32);
            sumf2[row] += d2 * (scales[2] - 32);

            s1 = s2 = s3 = s4 = s5 = s6 = 0;
            for (short l = 0; l < 8; l += 2) {
                const int32_t qs = q[l/2+8];
                s1 += yl[l+8] * (qs & qm[il/2][0]);
                s2 += yl[l+9] * (qs & qm[il/2][1]);
                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
                s4 += yl[l+24] * (qs & qm[il/2][2]);
                s5 += yl[l+25] * (qs & qm[il/2][3]);
                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
            }

            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);

            sumf1[row] += d1 * (scales[1] - 32);
            sumf2[row] += d2 * (scales[3] - 32);

            q  += args.nb01/2;
            h  += args.nb01/2;
            a  += args.nb01/2;
            dh += args.nb01/2;
        }

        y1 += 4 * QK_K;
    }

    for (int row = 0; row < nr0; ++row) {
        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
        sumf1[row] = simd_sum(sumf);
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    if (tiisg == 0) {
        for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
            dst_f32[first_row + row] = sumf1[row];
        }
    }
}

[[host_name("kernel_mul_mv_q3_K_f32")]]
kernel void kernel_mul_mv_q3_K_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, typename args_t>
void kernel_mul_mv_q4_K_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    constexpr uint16_t kmask1 = 0x3f3f;
    constexpr uint16_t kmask2 = 0x0f0f;
    constexpr uint16_t kmask3 = 0xc0c0;
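    // note: kmask1/kmask2/kmask3 extract the packed 6-bit scales and mins of q4_K: the low
    //       6 bits, the low 4 bits of the upper entries, and the 2 spill-over bits, which are
    //       recombined below into the byte view sc8[0..7]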
    const short ix = tiisg/8;  // 0...3
    const short it = tiisg%8;  // 0...7
    const short iq = it/4;     // 0 or 1
    const short ir = it%4;     // 0...3

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
    device const float      * y = (device const float      *) (src1 + offset1);

    float yl[16];
    float yh[16];

    float sumf[nr0]={0.f};

    device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;

    uint16_t sc16[4];
    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;

    for (int ib = ix; ib < nb; ib += 4) {
        float4 sumy = {0.f, 0.f, 0.f, 0.f};

        for (short i = 0; i < 8; ++i) {
            yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
            yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
            yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
            yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
        }

        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
        device const half     * dh = &x[ib].d;

        for (short row = 0; row < nr0; row++) {
            sc16[0] = sc[0] & kmask1;
            sc16[1] = sc[2] & kmask1;
            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);

            device const uint16_t * q2 = q1 + 32;

            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
            float4 acc2 = {0.f, 0.f, 0.f, 0.f};

            FOR_UNROLL (short i = 0; i < 4; ++i) {
                acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
                acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
                acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
                acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000);
                acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F);
                acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00);
                acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0);
                acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
            }

            sumf[row] += dh[0] * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
                                  (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
                                  (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
                                  (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
                         dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);

            q1 += args.nb01/2;
            sc += args.nb01/2;
            dh += args.nb01/2;
        }

        y4 += 4 * QK_K;
    }

    device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}

[[host_name("kernel_mul_mv_q4_K_f32")]]
kernel void kernel_mul_mv_q4_K_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, typename args_t>
void kernel_mul_mv_q5_K_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_q5_K * x  = (device const block_q5_K *) (src0 + offset0);
    device const float      * yy = (device const float      *) (src1 + offset1);

    float sumf[nr0]={0.f};

    float yl[16], yh[16];

    constexpr uint16_t kmask1 = 0x3f3f;
    constexpr uint16_t kmask2 = 0x0f0f;
    constexpr uint16_t kmask3 = 0xc0c0;

    const short tid = tiisg/4;
    const short ix  = tiisg%4;
    const short iq  = tid/4;
    const short ir  = tid%4;
    const short l0  = 8*ir;

    const short q_offset = 32*iq + l0;
    const short y_offset = 64*iq + l0;

    const uint8_t hm1 = 1u << (2*iq);
    const uint8_t hm2 = hm1 << 1;
    const uint8_t hm3 = hm1 << 4;
    const uint8_t hm4 = hm2 << 4;
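    // note: q5_K keeps the 5th bit of each weight in the separate qh plane; hm1/hm2 select the
    //       bits for the two low-nibble groups and hm3/hm4 for the two high-nibble groups,
    //       which is why a set bit contributes an extra 16.f in the accumulation below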
    uint16_t sc16[4];
    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;

    device const float * y1 = yy + ix*QK_K + y_offset;

    for (int i = ix; i < nb; i += 4) {
        device const uint8_t  * q1 = x[i].qs + q_offset;
        device const uint8_t  * qh = x[i].qh + l0;
        device const half     * dh = &x[i].d;
        device const uint16_t * a  = (device const uint16_t *)x[i].scales + iq;

        device const float * y2 = y1 + 128;

        float4 sumy = {0.f, 0.f, 0.f, 0.f};
        for (short l = 0; l < 8; ++l) {
            yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
            yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
            yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
            yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
        }

        for (short row = 0; row < nr0; ++row) {
            device const uint8_t * q2 = q1 + 64;

            sc16[0] = a[0] & kmask1;
            sc16[1] = a[2] & kmask1;
            sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
            sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);

            float4 acc1 = {0.f};
            float4 acc2 = {0.f};

            FOR_UNROLL (short l = 0; l < 8; ++l) {
                uint8_t h = qh[l];

                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
                acc1[3] += yh[l+8] * (q2[l] & 0xF0);

                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
            }

            sumf[row] += dh[0] * (sc8[0] * (acc1[0]      + 16.f*acc2[0]) +
                                  sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
                                  sc8[4] * (acc1[2]      + 16.f*acc2[2]) +
                                  sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                         dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);

            q1 += args.nb01;
            qh += args.nb01;
            dh += args.nb01/2;
            a  += args.nb01/2;
        }

        y1 += 4 * QK_K;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        const float tot = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = tot;
        }
    }
}

[[host_name("kernel_mul_mv_q5_K_f32")]]
kernel void kernel_mul_mv_q5_K_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, typename args_t>
void kernel_mul_mv_q6_K_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    constexpr uint8_t kmask1 = 0x03;
    constexpr uint8_t kmask2 = 0x0C;
    constexpr uint8_t kmask3 = 0x30;
    constexpr uint8_t kmask4 = 0xC0;
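    // note: q6_K weights are 6-bit: 4 bits from ql plus 2 bits from qh; kmask1..kmask4 pick the
    //       2-bit pair for each of the four sub-groups, and the reconstructed value is
    //       recentered by subtracting 32 (see the sums[] accumulation below)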
    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_q6_K * x  = (device const block_q6_K *) (src0 + offset0);
    device const float      * yy = (device const float      *) (src1 + offset1);

    float sumf[nr0] = { 0.f };

    float yl[16];

    const short tid = tiisg/2;
    const short ix  = tiisg%2;
    const short ip  = tid/8; // 0 or 1
    const short il  = tid%8;
    const short l0  = 4*il;
    const short is  = 8*ip + l0/16;

    const short y_offset   = 128*ip + l0;
    const short q_offset_l =  64*ip + l0;
    const short q_offset_h =  32*ip + l0;

    for (int i = ix; i < nb; i += 2) {
        device const uint8_t * q1 = x[i].ql + q_offset_l;
        device const uint8_t * q2 = q1 + 32;
        device const uint8_t * qh = x[i].qh + q_offset_h;
        device const int8_t  * sc = x[i].scales + is;
        device const half    * dh = &x[i].d;

        device const float * y = yy + i * QK_K + y_offset;

        for (short l = 0; l < 4; ++l) {
            yl[4*l + 0] = y[l +  0];
            yl[4*l + 1] = y[l + 32];
            yl[4*l + 2] = y[l + 64];
            yl[4*l + 3] = y[l + 96];
        }

        for (short row = 0; row < nr0; ++row) {
            float4 sums = {0.f, 0.f, 0.f, 0.f};

            FOR_UNROLL (short l = 0; l < 4; ++l) {
                sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
                sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
                sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
                sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
            }

            sumf[row] += dh[0] * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);

            q1 += args.nb01;
            q2 += args.nb01;
            qh += args.nb01;
            sc += args.nb01;
            dh += args.nb01/2;
        }
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}

[[host_name("kernel_mul_mv_q6_K_f32")]]
kernel void kernel_mul_mv_q6_K_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
// ======================= "True" 2-bit
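// note: these codebook-based quants decode 8 weights at a time from a shared grid plus a
//       7-bit sign index; block scales are decoded as d * (0.5f + ls) and the final result
//       carries a constant 0.25f factor (0.5f for iq3_xxs) applied at the store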
template<int nr0, typename args_t>
void kernel_mul_mv_iq2_xxs_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
    device const float         * y = (device const float         *) (src1 + offset1);

    float yl[32];
    float sumf[nr0]={0.f};

    const int nb32 = nb * (QK_K / 32);

    threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 256);
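    // note: cooperative staging - the threads copy disjoint slices of the 256-entry grid and
    //       the sign table into threadgroup memory, followed by a barrier before any reads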
    {
        int nval = 4;
        int pos  = (32*sgitg + tiisg)*nval;
        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i];
        nval = 2;
        pos  = (32*sgitg + tiisg)*nval;
        for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const int ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq2_xxs * xr = x + ibl;
        device const uint16_t      * q2 = xr->qs + 4 * ib;
        device const half          * dh = &xr->d;

        for (short row = 0; row < nr0; row++) {
            const float db = dh[0];

            device const uint8_t * aux8 = (device const uint8_t *)q2;
            const uint32_t aux32 = q2[2] | (q2[3] << 16);
            const float d = db * (0.5f + (aux32 >> 28));

            float sum = 0;
            for (short l = 0; l < 4; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
                for (short j = 0; j < 8; ++j) {
                    sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }

            sumf[row] += d * sum;

            dh += args.nb01/2;
            q2 += args.nb01/2;
        }

        y4 += 32 * 32;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all * 0.25f;
        }
    }
}

[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
kernel void kernel_mul_mv_iq2_xxs_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, typename args_t>
void kernel_mul_mv_iq2_xs_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
    device const float        * y = (device const float        *) (src1 + offset1);

    float yl[32];
    float sumf[nr0]={0.f};

    const int nb32 = nb * (QK_K / 32);

    threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 512);

    {
        int nval = 8;
        int pos  = (32*sgitg + tiisg)*nval;
        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i];
        nval = 2;
        pos  = (32*sgitg + tiisg)*nval;
        for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const int ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq2_xs * xr = x + ibl;
        device const uint16_t     * q2 = xr->qs + 4 * ib;
        device const uint8_t      * sc = xr->scales + ib;
        device const half         * dh = &xr->d;

        for (short row = 0; row < nr0; row++) {
            const float db = dh[0];
            const uint8_t ls1 = sc[0] & 0xf;
            const uint8_t ls2 = sc[0] >> 4;
            const float d1 = db * (0.5f + ls1);
            const float d2 = db * (0.5f + ls2);

            float sum1 = 0, sum2 = 0;
            for (short l = 0; l < 2; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
                const uint8_t signs = ssigns[(q2[l] >> 9)];
                for (short j = 0; j < 8; ++j) {
                    sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
            for (short l = 2; l < 4; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
                const uint8_t signs = ssigns[(q2[l] >> 9)];
                for (short j = 0; j < 8; ++j) {
                    sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }

            sumf[row] += d1 * sum1 + d2 * sum2;

            dh += args.nb01/2;
            q2 += args.nb01/2;
            sc += args.nb01;
        }

        y4 += 32 * 32;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all * 0.25f;
        }
    }
}

[[host_name("kernel_mul_mv_iq2_xs_f32")]]
kernel void kernel_mul_mv_iq2_xs_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

template<int nr0, typename args_t>
void kernel_mul_mv_iq3_xxs_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3 tgpig,
        ushort tiisg,
        ushort sgitg) {

    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;

    device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
    device const float         * y = (device const float         *) (src1 + offset1);

    float yl[32];
    float sumf[nr0] = {0.f};

    const int nb32 = nb * (QK_K / 32);

    threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem);
    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 256);
    {
        int nval = 4;
        int pos  = (32*sgitg + tiisg)*nval;
        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i];
        nval = 2;
        pos  = (32*sgitg + tiisg)*nval;
        for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const int ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq3_xxs * xr = x + ibl;
        device const uint8_t  * q3  = xr->qs + 8 * ib;
        device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
        device const half     * dh  = &xr->d;

        for (short row = 0; row < nr0; row++) {
            const float db = dh[0];
            const uint32_t aux32 = gas[0] | (gas[1] << 16);
            const float d = db * (0.5f + (aux32 >> 28));

            float2 sum = {0};
            for (short l = 0; l < 4; ++l) {
                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
                for (short j = 0; j < 4; ++j) {
                    sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
                    sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
                }
            }
            sumf[row] += d * (sum[0] + sum[1]);

            dh  += args.nb01/2;
            q3  += args.nb01;
            gas += args.nb01/2;
        }

        y4 += 32 * 32;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all * 0.5f;
        }
    }
}

[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
kernel void kernel_mul_mv_iq3_xxs_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
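
// note: informal decode sketch for iq3_xxs (derived from the code above, not load-bearing):
//   each byte of q3 indexes the 256-entry iq3xxs_grid (one uint32_t = 4 magnitudes), so the
//   8 bytes cover 32 weights; aux32 = gas[0] | (gas[1] << 16) packs the rest of the block:
//     bits 0..27  : four 7-bit sign-table indices, extracted as (aux32 >> 7*l) & 127
//     bits 28..31 : the 4-bit scale s, applied as d*(0.5f + s) with a final *0.5f on the row sum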

template<int nr0, typename args_t>
void kernel_mul_mv_iq3_s_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3 tgpig,
        ushort tiisg,
        ushort sgitg) {

    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;

    device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
    device const float       * y = (device const float       *) (src1 + offset1);

    float yl[32];
    float sumf[nr0] = {0.f};

    const int nb32 = nb * (QK_K / 32);

    threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem;
    {
        int nval = 8;
        int pos  = (32*sgitg + tiisg)*nval;
        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const int ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq3_s * xr = x + ibl;
        device const uint8_t * qs    = xr->qs + 8 * ib;
        device const uint8_t * qh    = xr->qh + ib;
        device const uint8_t * sc    = xr->scales + (ib/2);
        device const uint8_t * signs = xr->signs + 4 * ib;
        device const half    * dh    = &xr->d;

        for (short row = 0; row < nr0; row++) {
            const float db = dh[0];
            const float d  = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));

            float2 sum = {0};
            for (short l = 0; l < 4; ++l) {
                const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
                const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
                for (short j = 0; j < 4; ++j) {
                    sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
                    sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
                }
            }
            sumf[row] += d * (sum[0] + sum[1]);

            dh    += args.nb01/2;
            qs    += args.nb01;
            qh    += args.nb01;
            sc    += args.nb01;
            signs += args.nb01;
        }

        y4 += 32 * 32;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}

[[host_name("kernel_mul_mv_iq3_s_f32")]]
kernel void kernel_mul_mv_iq3_s_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
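
// note: informal decode sketch for iq3_s (derived from the code above, not load-bearing):
//   the 512-entry iq3s_grid is addressed with 8 bits from qs plus a 9th bit from qh,
//   implemented as a +256 table offset: table = (qh[0] & kmask) ? svalues + 256 : svalues;
//   signs are stored explicitly (4 bytes per 32 weights) instead of through the shared 7-bit
//   sign codebook, and the scale enters as d*(1 + 2*s), so no trailing factor is needed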

template<int nr0, typename args_t>
void kernel_mul_mv_iq2_s_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3 tgpig,
        ushort tiisg,
        ushort sgitg) {

    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;

    device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
    device const float       * y = (device const float       *) (src1 + offset1);

    float yl[32];
    float sumf[nr0] = {0.f};

    const int nb32 = nb * (QK_K / 32);

    //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem;
    //{
    //    int nval = 32;
    //    int pos  = (32*sgitg + tiisg)*nval;
    //    for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i];
    //    threadgroup_barrier(mem_flags::mem_threadgroup);
    //}

    const short ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq2_s * xr = x + ibl;
        device const uint8_t * qs    = xr->qs + 4 * ib;
        device const uint8_t * qh    = xr->qh + ib;
        device const uint8_t * sc    = xr->scales + ib;
        device const uint8_t * signs = qs + QK_K/8;
        device const half    * dh    = &xr->d;

        for (short row = 0; row < nr0; row++) {
            const float db = dh[0];
            const float d1 = db * (0.5f + (sc[0] & 0xf));
            const float d2 = db * (0.5f + (sc[0] >>  4));

            float2 sum = {0};
            for (short l = 0; l < 2; ++l) {
                //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
                //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
                constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
                constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
                for (short j = 0; j < 8; ++j) {
                    sum[0] += yl[8*l + j +  0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
                    sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
                }
            }
            sumf[row] += d1 * sum[0] + d2 * sum[1];

            dh    += args.nb01/2;
            qs    += args.nb01;
            qh    += args.nb01;
            sc    += args.nb01;
            signs += args.nb01;
        }

        y4 += 32 * 32;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all * 0.25f;
        }
    }
}

[[host_name("kernel_mul_mv_iq2_s_f32")]]
kernel void kernel_mul_mv_iq2_s_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
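
// note: informal decode sketch for iq2_s (derived from the code above, not load-bearing):
//   the 10-bit grid index combines 8 bits from qs with 2 high bits from qh:
//     iq2s_grid[qs[l] | ((qh[0] << shift) & 0x300)]
//   signs are stored as plain bytes at qs + QK_K/8; the threadgroup staging of iq2s_grid is
//   left commented out above and the table is read directly from constant memory instead,
//   likely because the 1024-entry, 8-byte-per-entry table is too large to cache profitably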

template<int nr0, typename args_t>
void kernel_mul_mv_iq1_s_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3 tgpig,
        ushort tiisg,
        ushort sgitg) {

    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;

    device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
    device const float       * y = (device const float       *) (src1 + offset1);

    float yl[32];
    float sumf[nr0] = {0.f};

    const int nb32 = nb * (QK_K / 32);

    const short ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        float sumy = 0;
        for (short i = 0; i < 32; ++i) {
            yl[i] = y4[i];
            sumy += yl[i];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq1_s * xr = x + ibl;
        device const uint8_t  * qs = xr->qs + 4 * ib;
        device const uint16_t * qh = xr->qh + ib;
        device const half     * dh = &xr->d;

        for (short row = 0; row < nr0; row++) {
            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));

            float sum = 0;
            for (short j = 0; j < 4; ++j) {
                sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                     + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
                     + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
                     + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
            }
            sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);

            dh += args.nb01/2;
            qs += args.nb01;
            qh += args.nb01/2;
        }

        y4 += 32 * 32;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}

[[host_name("kernel_mul_mv_iq1_s_f32")]]
kernel void kernel_mul_mv_iq1_s_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
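
// note: informal decode sketch for iq1_s (derived from the code above, not load-bearing):
//   each byte of qs plus 3 bits of qh forms an 11-bit index into iq1s_grid_gpu, where each
//   grid byte packs two values as nibbles (grid[j] & 0xf, grid[j] >> 4);
//   per 32 weights, qh[0] packs: bits 0..11 = four 3-bit high index parts, bits 12..14 = the
//   scale (applied as 2*s + 1), bit 15 = the sign of the IQ1S_DELTA shift, which is applied
//   to all 32 weights at once through the precomputed row sum sumy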

template<int nr0, typename args_t>
void kernel_mul_mv_iq1_m_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3 tgpig,
        ushort tiisg,
        ushort sgitg) {

    const short NSG = FC_mul_mv_nsg;

    const int nb = args.ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;

    device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
    device const float       * y = (device const float       *) (src1 + offset1);

    float yl[32];
    float sumf[nr0] = {0.f};

    const int nb32 = nb * (QK_K / 32);

    const short ix = tiisg;

    device const float * y4 = y + 32 * ix;

    iq1m_scale_t scale;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
        float4 sumy = {0.f};
        for (short i = 0; i < 8; ++i) {
            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
            yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
            yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
            yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq1_m * xr = x + ibl;
        device const uint8_t  * qs = xr->qs + 4 * ib;
        device const uint8_t  * qh = xr->qh + 2 * ib;
        device const uint16_t * sc = (device const uint16_t *)xr->scales;

        for (short row = 0; row < nr0; row++) {
            scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));

            float2 sum = {0.f};
            for (short j = 0; j < 4; ++j) {
                sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                        + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
                sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
                        + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
            }

            const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
            const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);

            sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
                                             (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));

            sc += args.nb01/2;
            qs += args.nb01;
            qh += args.nb01;
        }

        y4 += 32 * 32;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}

[[host_name("kernel_mul_mv_iq1_m_f32")]]
kernel void kernel_mul_mv_iq1_m_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
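
// note: informal decode sketch for iq1_m (derived from the code above, not load-bearing):
//   the block's fp16 super-scale is scattered across the top 4 bits of the four 16-bit scale
//   words and reassembled into scale.u16; the per-16-weight 3-bit scales come from sc[ib/2],
//   and each 8-weight group gets its own +/-IQ1M_DELTA shift, selected by bits 0x08/0x80 of
//   qh[0] and qh[1] and folded in through the partial sums sumy[0..3]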

template<int NR0, typename args_t>
void kernel_mul_mv_iq4_nl_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3 tgpig,
        ushort tiisg,
        ushort sgitg) {

    const short NSG = FC_mul_mv_nsg;

    threadgroup float * shmem_f32 = (threadgroup float *) shmem;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * NR0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;

    device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
    device const float        * y = (device const float        *) (src1 + offset1);

    const int nb   = args.ne00/QK4_NL;
    const int ns01 = args.nb01/args.nb00;

    const short ix = tiisg/2; // 0...15
    const short it = tiisg%2; // 0 or 1

    shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float4 yl[4];
    float sumf[NR0] = {0.f};

    device const float * yb = y + ix*QK4_NL + it*8;

    uint32_t aux32[2];
    thread const uint8_t * q8 = (thread const uint8_t *)aux32;

    float4 qf1, qf2;

    // [TAG_MUL_MV_WEIRD]
    for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
        device const float4 * y4 = (device const float4 *)yb;
        yl[0] = y4[0];
        yl[1] = y4[4];
        yl[2] = y4[1];
        yl[3] = y4[5];

        for (short row = 0; row < NR0; row++) {
            device const block_iq4_nl & xb = x[row*ns01 + ib];
            device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);

            float4 acc1 = {0.f}, acc2 = {0.f};

            aux32[0] = q4[0] | (q4[1] << 16);
            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
            aux32[0] &= 0x0f0f0f0f;
            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};

            acc1 += yl[0] * qf1;
            acc2 += yl[1] * qf2;

            aux32[0] = q4[2] | (q4[3] << 16);
            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
            aux32[0] &= 0x0f0f0f0f;
            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};

            acc1 += yl[2] * qf1;
            acc2 += yl[3] * qf2;

            acc1 += acc2;

            sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
        }

        yb += 16 * QK4_NL;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}

[[host_name("kernel_mul_mv_iq4_nl_f32")]]
kernel void kernel_mul_mv_iq4_nl_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
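
// note: informal sketch of the iq4_nl decode (derived from the code above, not load-bearing):
//   the 16-entry non-linear LUT kvalues_iq4nl_f is staged into threadgroup memory first
//   (shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16], i.e. two copies for the 32 threads);
//   nibbles are then split in bulk with two word-wide masks:
//     aux32[0] = q4[0] | (q4[1] << 16);  aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;  aux32[0] &= 0x0f0f0f0f;
//   after which q8 reinterprets the two words as 8 byte-sized LUT indices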

template<int NR0, typename args_t>
void kernel_mul_mv_iq4_xs_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3 tgpig,
        ushort tiisg,
        ushort sgitg) {

    const short NSG = FC_mul_mv_nsg;

    threadgroup float * shmem_f32 = (threadgroup float *) shmem;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * NR0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;

    device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
    device const float        * y = (device const float        *) (src1 + offset1);

    const int nb   = args.ne00/QK_K;
    const int ns01 = args.nb01/args.nb00;

    const short ix = tiisg/16; // 0 or 1
    const short it = tiisg%16; // 0...15
    const short ib = it/2;
    const short il = it%2;

    shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float4 yl[4];
    float sumf[NR0] = {0.f};

    device const float * yb = y + ix * QK_K + ib * 32 + il * 8;

    uint32_t aux32[2];
    thread const uint8_t * q8 = (thread const uint8_t *)aux32;

    float4 qf1, qf2;

    // [TAG_MUL_MV_WEIRD]
    for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) {
        device const float4 * y4 = (device const float4 *)yb;
        yl[0] = y4[0];
        yl[1] = y4[4];
        yl[2] = y4[1];
        yl[3] = y4[5];

        for (short row = 0; row < NR0; ++row) {
            device const block_iq4_xs & xb = x[row*ns01 + ibl];
            device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);

            float4 acc1 = {0.f}, acc2 = {0.f};

            aux32[0] = (q4[0]     ) & 0x0f0f0f0f;
            aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};

            acc1 += yl[0] * qf1;
            acc2 += yl[1] * qf2;

            aux32[0] = (q4[1]     ) & 0x0f0f0f0f;
            aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};

            acc1 += yl[2] * qf1;
            acc2 += yl[3] * qf2;

            acc1 += acc2;

            const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
            sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
        }

        yb += 2 * QK_K;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}

[[host_name("kernel_mul_mv_iq4_xs_f32")]]
kernel void kernel_mul_mv_iq4_xs_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
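
// note: informal sketch of the iq4_xs scale unpacking (derived from the code above):
//   each 32-weight sub-block has a 6-bit scale: 4 low bits from a nibble of scales_l and
//   2 high bits from a 2-bit field of scales_h, biased by -32:
//     ls = (((scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((scales_h >> 2*ib) & 3) << 4)) - 32;
//   e.g. scales_l nibble = 0x9 and scales_h field = 2 give ls = (0x9 | (2 << 4)) - 32 = 9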

template<int NR0, typename args_t>
void kernel_mul_mv_mxfp4_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3 tgpig,
        ushort tiisg,
        ushort sgitg) {

    const short NSG = FC_mul_mv_nsg;

    threadgroup float * shmem_f32 = (threadgroup float *) shmem;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * NR0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;

    device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
    device const float       * y = (device const float       *) (src1 + offset1);

    const int nb   = args.ne00/QK_MXFP4;
    const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors

    const short ix = tiisg/2; // 0...15
    const short it = tiisg%2; // 0 or 1

    shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float4 yl[4];
    float sumf[NR0] = {0.f};

    device const float * yb = y + ix*QK_MXFP4 + it*8;

    // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster
    //       no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD]
    for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
        device const float4 * y4 = (device const float4 *) yb;

        yl[0] = y4[0];
        yl[1] = y4[4];
        yl[2] = y4[1];
        yl[3] = y4[5];

        FOR_UNROLL (short row = 0; row < NR0; row++) {
            device const block_mxfp4 & xb = x[row*ns01 + ib];
            device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);

            float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
            float4 acc2 = yl[1]*float4(shmem_f32[q2[0] >> 4  ], shmem_f32[q2[1] >> 4  ], shmem_f32[q2[2] >> 4  ], shmem_f32[q2[3] >> 4  ]);
            float4 acc3 = yl[2]*float4(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]);
            float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4  ], shmem_f32[q2[5] >> 4  ], shmem_f32[q2[6] >> 4  ], shmem_f32[q2[7] >> 4  ]);

            acc1 = (acc1 + acc3) + (acc2 + acc4);

            sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + acc1[3]));
        }

        yb += 16 * QK_MXFP4;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
        float sum_all = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = sum_all;
        }
    }
}

[[host_name("kernel_mul_mv_mxfp4_f32")]]
kernel void kernel_mul_mv_mxfp4_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
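
// note: informal sketch of the mxfp4 decode (derived from the code above, not load-bearing):
//   each nibble of qs indexes the 16-entry fp4 LUT kvalues_mxfp4_f, and the whole block
//   shares one e8m0 exponent byte applied as e8m0_to_fp32(xb.e), i.e. a power-of-two scale;
//   the low nibble of byte k is weight k and the high nibble is weight k + 16, which is why
//   the y loads are interleaved as yl[0] = y4[0], yl[1] = y4[4], yl[2] = y4[1], yl[3] = y4[5]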

template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
        constant ggml_metal_kargs_get_rows & args,
        device const void * src0,
        device const void * src1,
        device void * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg [[threads_per_threadgroup]]) {
    const int32_t iw0 = tgpig.x/args.ne10;
    const int32_t i10 = tgpig.x%args.ne10;
    const int32_t i11 = tgpig.y;
    const int32_t i12 = tgpig.z;

    const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];

    const int32_t i02 = i11;
    const int32_t i03 = i12;

    auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
    auto pdst = (device float4x4     *) ((      device char *) dst  + i12*args.nb3  + i11*args.nb2  + i10*args.nb1);

    // the loop condition is only a bounds guard: with the unconditional break, each thread
    // dequantizes at most one float4x4 block at index ind
    for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
        float4x4 temp;
        dequantize_func(psrc + ind/nl, ind%nl, temp);
        pdst[ind] = temp;
        break;
    }
}

template<typename T0, typename T>
kernel void kernel_get_rows_f(
        constant ggml_metal_kargs_get_rows & args,
        device const void * src0,
        device const void * src1,
        device void * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg [[threads_per_threadgroup]]) {
    const int32_t iw0 = tgpig.x/args.ne10;
    const int32_t i10 = tgpig.x%args.ne10;
    const int32_t i11 = tgpig.y;
    const int32_t i12 = tgpig.z;

    const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];

    const int32_t i02 = i11;
    const int32_t i03 = i12;

    auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
    auto pdst = (      device T  *) ((      device char *) dst  + i12*args.nb3  + i11*args.nb2  + i10*args.nb1);

    // as in kernel_get_rows_q, the loop is only a bounds guard - each thread copies at most one element
    for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
        pdst[ind] = psrc[ind];
        break;
    }
}

template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_set_rows_q32(
        constant ggml_metal_kargs_set_rows & args,
        device const void * src0,
        device const void * src1,
        device float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
    const int32_t i03 = tgpig.z;
    const int32_t i02 = tgpig.y;

    const int32_t i12 = i03%args.ne12;
    const int32_t i11 = i02%args.ne11;

    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
    if (i01 >= args.ne01) {
        return;
    }

    const int32_t i10 = i01;
    const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];

    device       block_q * dst_row = (      device block_q *) ((      device char *) dst  + i1*args.nb1   + i02*args.nb2  + i03*args.nb3);
    const device float   * src_row = (const device float   *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);

    for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
        quantize_func(src_row + 32*ind, dst_row[ind]);
    }
}

template<typename T, typename TI>
kernel void kernel_set_rows_f(
        constant ggml_metal_kargs_set_rows & args,
        device const void * src0,
        device const void * src1,
        device float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
    const int32_t i03 = tgpig.z;
    const int32_t i02 = tgpig.y;

    const int32_t i12 = i03%args.ne12;
    const int32_t i11 = i02%args.ne11;

    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
    if (i01 >= args.ne01) {
        return;
    }

    const int32_t i10 = i01;
    const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];

    device       T     * dst_row = (      device T     *) ((      device char *) dst  + i1*args.nb1   + i02*args.nb2  + i03*args.nb3);
    const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);

    for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
        dst_row[ind] = (T) src_row[ind];
    }
}
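
// note: informal sketch of the set_rows indexing (derived from the code above, not load-bearing):
//   src1 supplies one destination row id per source row (i10 == i01), broadcast over the
//   i11/i12 batch dims by the modulo with ne11/ne12; the threads of one row then stride by
//   tptg.x over the nk0 chunks (one block_q per 32 floats in the quantizing variant,
//   one value per chunk in the plain-copy variant)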

constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];

// each block_q contains 16*nl weights
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm(
        constant ggml_metal_kargs_mul_mm & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    threadgroup S0    * sa = (threadgroup S0    *)(shmem);
    threadgroup S1    * sb = (threadgroup S1    *)(shmem + 4096);
    threadgroup float * sc = (threadgroup float *)(shmem);

    constexpr int NR0 = 64;
    constexpr int NR1 = 32;
    constexpr int NK  = 32;
    constexpr int NL0 = NK/16;
    constexpr int NL1 = NK/8;

    const int im = tgpig.z;
    const int r0 = tgpig.y*NR0;
    const int r1 = tgpig.x*NR1;

    // if this block is of 64x32 shape or smaller
    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
    const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;

    // a thread shouldn't load data outside of the matrix
    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
    const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31

    const short il0 = (tiitg % NL0);

    short il = il0;

    const int i12 = im%args.ne12;
    const int i13 = im/args.ne12;

    const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const short    offset1 = il0/nl;

    device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;

    const short iy = 8*(tiitg % NL1);

    device const T1 * y = (device const T1 *)(src1
        + args.nb13*i13
        + args.nb12*i12
        + args.nb11*(r1 + lr1)
        + args.nb10*iy);

#ifndef GGML_METAL_HAS_TENSOR
    S0_8x8 ma[4];
    S1_8x8 mb[2];

    simdgroup_float8x8 mc[8];

    for (short i = 0; i < 8; i++){
        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
    }
#else
    auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
    auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));

    mpp::tensor_ops::matmul2d<
        mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
        execution_simdgroups<4>> mm;

    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
#endif
    for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
#ifndef GGML_METAL_HAS_TENSOR
        // load data and store to threadgroup memory
        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
            threadgroup_barrier(mem_flags::mem_threadgroup);

            // no need for dequantization
            for (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                //const short lx = i%8;
                //const short ly = (tiitg/NL0)%8;
                const short lx = (tiitg/NL0)%8;
                const short ly = i%8;

                const short ib = 8*sx + sy;

                *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
            }
        } else {
            S0_4x4 temp_a;
            dequantize_func(x, il, temp_a);

            threadgroup_barrier(mem_flags::mem_threadgroup);

            FOR_UNROLL (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                //const short lx = i%8;
                //const short ly = (tiitg/NL0)%8;
                const short lx = (tiitg/NL0)%8;
                const short ly = i%8;

                const short ib = 8*sx + sy;

                // NOTE: this is massively slower.. WTF?
                //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
                *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
            }
        }

        if (FC_mul_mm_bc_inp) {
            for (short i = 0; i < 8; ++i) {
                const short sx = (tiitg%NL1);
                const short sy = (tiitg/NL1)/8;

                const short lx = i;
                const short ly = (tiitg/NL1)%8;
                //const short lx = (tiitg/NL1)%8;
                //const short ly = i;

                const short ib = 4*sx + sy;

                *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
            }
        } else {
            const short sx = (tiitg%NL1);
            const short sy = (tiitg/NL1)/8;

            const short ly = (tiitg/NL1)%8;

            const short ib = 4*sx + sy;

            *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
        }
#else
        // load data and store to threadgroup memory
        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
            threadgroup_barrier(mem_flags::mem_threadgroup);

            // no need for dequantization
            for (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                const short lx = i%8;
                const short ly = (tiitg/NL0)%8;
                //const short lx = (tiitg/NL0)%8;
                //const short ly = i%8;

                *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
            }
        } else {
            S0_4x4 temp_a;
            dequantize_func(x, il, temp_a);

            threadgroup_barrier(mem_flags::mem_threadgroup);

            FOR_UNROLL (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                const short lx = i%8;
                const short ly = (tiitg/NL0)%8;
                //const short lx = (tiitg/NL0)%8;
                //const short ly = i%8;

                *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
            }
        }

        if (FC_mul_mm_bc_inp) {
            for (short i = 0; i < 8; ++i) {
                const short sx = (tiitg%NL1);
                const short sy = (tiitg/NL1)/8;

                const short lx = i;
                const short ly = (tiitg/NL1)%8;
                //const short lx = (tiitg/NL1)%8;
                //const short ly = i;

                *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
            }
        } else {
            const short sx = (tiitg%NL1);
            const short sy = (tiitg/NL1)/8;

            const short ly = (tiitg/NL1)%8;

            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
        }
#endif
        il = (il + 2 < nl) ? il + 2 : il % 2;
        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
        y += NK;

        threadgroup_barrier(mem_flags::mem_threadgroup);
#ifndef GGML_METAL_HAS_TENSOR
        // load matrices from threadgroup memory and conduct outer products
        threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
        threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));

        FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
            simdgroup_barrier(mem_flags::mem_none);

            FOR_UNROLL (short i = 0; i < 4; i++) {
                simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
            }

            simdgroup_barrier(mem_flags::mem_none);

            FOR_UNROLL (short i = 0; i < 2; i++) {
                simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
            }

            simdgroup_barrier(mem_flags::mem_none);

            FOR_UNROLL (short i = 0; i < 8; i++){
                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
            }

            lsma += 8*64;
            lsmb += 4*64;
        }
#else
        auto sA = tA.slice(0, 0);
        auto sB = tB.slice(0, 0);

        mm.run(sB, sA, cT);
#endif
    }
    if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
        // if no bounds checks on the output are needed, we can directly write to device memory
#ifdef GGML_METAL_HAS_TENSOR
        device float * C = (device float *) dst +
            r0 +
            r1 * args.ne0 + im*args.ne1*args.ne0;

        auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
        cT.store(tC);
#else
        device float * C = (device float *) dst +
            (r0 + 32*(sgitg &  1)) +
            (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;

        for (short i = 0; i < 8; i++) {
            simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
        }
#endif
    } else {
        // block is smaller than 64x32, we should avoid writing data outside of the matrix
        threadgroup_barrier(mem_flags::mem_threadgroup);

        threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;

#ifdef GGML_METAL_HAS_TENSOR
        auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
        cT.store(tC);
#else
        for (short i = 0; i < 8; i++) {
            simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
        }
#endif

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (sgitg == 0) {
            for (int j = tiitg; j < nr1; j += NR1) {
                device float  * D  = (device float  *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
                device float4 * D4 = (device float4 *) D;

                threadgroup float  * C  = temp_str + (j*NR0);
                threadgroup float4 * C4 = (threadgroup float4 *) C;

                int i = 0;
                for (; i < nr0/4; i++) {
                    *(D4 + i) = *(C4 + i);
                }

                i *= 4;
                for (; i < nr0; i++) {
                    *(D + i) = *(C + i);
                }
            }
        }
    }
}
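
// note: informal sketch of the kernel_mul_mm tiling (derived from the code above, not load-bearing):
//   each threadgroup produces a 64x32 tile of dst (NR0 rows of src0 x NR1 rows of src1),
//   marching over K in NK = 32 chunks; sa holds the dequantized 64x32 src0 tile (src1 data
//   starts at byte offset 4096) and sb the 32x32 src1 tile; without tensor ops, the 4
//   simdgroups split the tile 2x2, each accumulating a 32x16 patch as 8 simdgroup_float8x8
//   multiply-accumulates per K-chunk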

template<short ne20> // n_expert_used
kernel void kernel_mul_mm_id_map0(
        constant ggml_metal_kargs_mul_mm_id_map0 & args,
        device const char * src2,
        device char * htpe,
        device char * hids,
        threadgroup char * shmem [[threadgroup(0)]],
        ushort tpitg[[thread_position_in_threadgroup]],
        ushort ntg[[threads_per_threadgroup]]) {
    const short ide = tpitg; // expert id

    uint32_t n_all = 0;

    device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;

    for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
        if (i21 + tpitg < args.ne21) {
            device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);

            threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;

#pragma unroll(ne20)
            for (short i20 = 0; i20 < ne20; i20++) {
                sids[i20] = src2_i32[i20];
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (short t = 0; t < ntg; t++) {
            if (i21 + t >= args.ne21) {
                break;
            }

            threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;

            short sel = 0;
#pragma unroll(ne20)
            for (short i20 = 0; i20 < ne20; i20++) {
                sel += (sids[i20] == ide)*(i20 + 1);
            }

            ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;

            n_all += sel > 0;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
    tpe_u32[ide] = n_all;
}

typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;

template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
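
// note: informal sketch of mul_mm_id_map0 (derived from the code above, not load-bearing):
//   one thread per expert scans the [ne20 x ne21] routing table in src2, staged through
//   shmem in chunks of ntg tokens; for every token that selected this expert it appends the
//   packed id (i21 + t)*ne20 + slot to its row of hids and finally writes the row count to
//   htpe - these per-expert lists drive the gather/scatter in kernel_mul_mm_id below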

template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(
        constant ggml_metal_kargs_mul_mm_id & args,
        device const char * src0,
        device const char * src1,
        device const char * htpe,
        device const char * hids,
        device char * dst,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    threadgroup S0    * sa = (threadgroup S0    *)(shmem);
    threadgroup S1    * sb = (threadgroup S1    *)(shmem + 4096);
    threadgroup float * sc = (threadgroup float *)(shmem);

    constexpr int NR0 = 64;
    constexpr int NR1 = 32;
    constexpr int NK  = 32;
    constexpr int NL0 = NK/16;
    constexpr int NL1 = NK/8;

    const int im = tgpig.z; // expert
    const int r0 = tgpig.y*NR0;
    const int r1 = tgpig.x*NR1;

    device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
    device const int32_t  * ids_i32 = (device const int32_t  *) (hids);

    const int32_t neh1 = tpe_u32[im];

    if (r1 >= neh1) {
        return;
    }

    // if this block is of 64x32 shape or smaller
    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
    const short nr1 = (     neh1 - r1 < NR1) ? (     neh1 - r1) : NR1;

    // a thread shouldn't load data outside of the matrix
    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
    const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31

    const short il0 = (tiitg % NL0);

    short il = il0;

    const int id = ids_i32[im*args.ne21 + r1 + lr1];

    const short i11 = (id % args.ne20) % args.ne11;
    const short i12 = (id / args.ne20);
    const short i13 = 0;

    const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
    const short    offset1 = il0/nl;

    device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;

    const short iy = 8*(tiitg % NL1);

    device const T1 * y = (device const T1 *)(src1
        + args.nb13*i13
        + args.nb12*i12
        + args.nb11*i11
        + args.nb10*iy);

#ifndef GGML_METAL_HAS_TENSOR
    S0_8x8 ma[4];
    S1_8x8 mb[2];

    simdgroup_float8x8 mc[8];

    for (short i = 0; i < 8; i++){
        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
    }
#else
    auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
    auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));

    mpp::tensor_ops::matmul2d<
        mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
        execution_simdgroups<4>> mm;

    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
#endif
    for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
#ifndef GGML_METAL_HAS_TENSOR
        // load data and store to threadgroup memory
        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
            threadgroup_barrier(mem_flags::mem_threadgroup);

            // no need for dequantization
            for (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                //const short lx = i%8;
                //const short ly = (tiitg/NL0)%8;
                const short lx = (tiitg/NL0)%8;
                const short ly = i%8;

                const short ib = 8*sx + sy;

                *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
            }
        } else {
            S0_4x4 temp_a;
            dequantize_func(x, il, temp_a);

            threadgroup_barrier(mem_flags::mem_threadgroup);

            FOR_UNROLL (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                //const short lx = i%8;
                //const short ly = (tiitg/NL0)%8;
                const short lx = (tiitg/NL0)%8;
                const short ly = i%8;

                const short ib = 8*sx + sy;

                // NOTE: this is massively slower.. WTF?
                //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
                *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
            }
        }

        if (FC_mul_mm_bc_inp) {
            for (short i = 0; i < 8; ++i) {
                const short sx = (tiitg%NL1);
                const short sy = (tiitg/NL1)/8;

                const short lx = i;
                const short ly = (tiitg/NL1)%8;
                //const short lx = (tiitg/NL1)%8;
                //const short ly = i;

                const short ib = 4*sx + sy;

                *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
            }
        } else {
            const short sx = (tiitg%NL1);
            const short sy = (tiitg/NL1)/8;

            const short ly = (tiitg/NL1)%8;

            const short ib = 4*sx + sy;

            *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
        }
#else
        // load data and store to threadgroup memory
        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
            threadgroup_barrier(mem_flags::mem_threadgroup);

            // no need for dequantization
            for (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                const short lx = i%8;
                const short ly = (tiitg/NL0)%8;
                //const short lx = (tiitg/NL0)%8;
                //const short ly = i%8;

                *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
            }
        } else {
            S0_4x4 temp_a;
            dequantize_func(x, il, temp_a);

            threadgroup_barrier(mem_flags::mem_threadgroup);

            FOR_UNROLL (short i = 0; i < 16; i++) {
                const short sx = 2*il0 + i/8;
                const short sy = (tiitg/NL0)/8;

                const short lx = i%8;
                const short ly = (tiitg/NL0)%8;
                //const short lx = (tiitg/NL0)%8;
                //const short ly = i%8;

                *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
            }
        }

        if (FC_mul_mm_bc_inp) {
            for (short i = 0; i < 8; ++i) {
                const short sx = (tiitg%NL1);
                const short sy = (tiitg/NL1)/8;

                const short lx = i;
                const short ly = (tiitg/NL1)%8;
                //const short lx = (tiitg/NL1)%8;
                //const short ly = i;

                *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
            }
        } else {
            const short sx = (tiitg%NL1);
            const short sy = (tiitg/NL1)/8;

            const short ly = (tiitg/NL1)%8;

            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
        }
#endif
        il = (il + 2 < nl) ? il + 2 : il % 2;
        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
        y += NK;

        threadgroup_barrier(mem_flags::mem_threadgroup);
#ifndef GGML_METAL_HAS_TENSOR
        // load matrices from threadgroup memory and conduct outer products
        threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
        threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));

        FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
            simdgroup_barrier(mem_flags::mem_none);

            FOR_UNROLL (short i = 0; i < 4; i++) {
                simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
            }

            simdgroup_barrier(mem_flags::mem_none);

            FOR_UNROLL (short i = 0; i < 2; i++) {
                simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
            }

            simdgroup_barrier(mem_flags::mem_none);

            FOR_UNROLL (short i = 0; i < 8; i++){
                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
            }

            lsma += 8*64;
            lsmb += 4*64;
        }
#else
        auto sA = tA.slice(0, 0);
        auto sB = tB.slice(0, 0);

        mm.run(sB, sA, cT);
#endif
    }
    // the output block may be smaller than 64x32, so avoid writing data outside of the dst matrix
    threadgroup_barrier(mem_flags::mem_threadgroup);

#ifdef GGML_METAL_HAS_TENSOR
    auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
    cT.store(tC);
#else
    threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;

    for (short i = 0; i < 8; i++) {
        simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
    }
#endif

    threadgroup_barrier(mem_flags::mem_threadgroup);
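
    // write the staged results to dst: each of the 4 simdgroups handles every
    // 4th row j; ids_i32 maps j to indices (ide, idt) that select the
    // destination row and matrix. The copy uses a vectorized float4 body plus
    // a scalar tail for row lengths that are not a multiple of 4.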
    for (short j = sgitg; j < nr1; j += 4) {
        const int id = ids_i32[im*args.ne21 + r1 + j];

        const short ide = id % args.ne20;
        const short idt = id / args.ne20;

        device float  * D  = (device float  *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
        device float4 * D4 = (device float4 *) D;

        threadgroup float  * C  = (threadgroup float  *) shmem + j*NR0;
        threadgroup float4 * C4 = (threadgroup float4 *) C;

        int i = tiisg;
        for (; i < nr0/4; i += 32) {
            *(D4 + i) = *(C4 + i);
        }

        i = (4*(nr0/4)) + tiisg;
        for (; i < nr0; i += 32) {
            *(D + i) = *(C + i);
        }
    }
}
#define QK_NL 16
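
// QK_NL is the `nl` template argument used for the 256-wide K-quant and iq
// block types below: 16 sub-blocks of 16 elements each (16*16 == 256). The
// 32-wide block types (q4_0, q8_0, iq4_nl, ...) use nl == 2 instead. This is
// an inference from the instantiations that follow, not an upstream comment.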
//
// get rows
//

typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;

template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>;
template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
#endif

typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;

template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
//
// set rows
//

typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;

template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;
template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;
template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
#endif

typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;

template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;
//
// matrix-matrix multiplication
//
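
// the template arguments, as read off the instantiations below (an inference
// from the kernel body above, not an upstream comment), are:
//   <T0, T0_4x4, simdgroup T0 8x8, T1, T1_2x4, simdgroup T1 8x8,
//    block_q, nl, dequantize_func, S0, S0_4x4, S1, S1_2x4>
// i.e. the device-side src0/src1 element types, the quant block type with its
// sub-block count and dequantizer, and the shared-memory staging types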
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t;

template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;

template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
#endif
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
//
// indirect matrix-matrix multiplication
//

typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id;

template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;

template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
#endif
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
//
// matrix-vector multiplication
//

typedef void (kernel_mul_mv_disp_t)(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig,
        ushort tiisg);

typedef void (kernel_mul_mv2_disp_t)(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg);
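
// two dispatch signatures exist because some mat-vec implementations are
// simdgroup-local, while others also need threadgroup memory and the
// simdgroup index. The two mmv_fn overloads below adapt both shapes to one
// common signature, so kernel_mul_mv_id can take either as a template argument.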
template<kernel_mul_mv_disp_t disp_fn>
void mmv_fn(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiitg,
        ushort tiisg,
        ushort sgitg) {
    disp_fn(args, src0, src1, dst, tgpig, tiisg);
}

template<kernel_mul_mv2_disp_t disp_fn>
void mmv_fn(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiitg,
        ushort tiisg,
        ushort sgitg) {
    disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
typedef decltype(mmv_fn<kernel_mul_mv_t_t_disp<half, half, ggml_metal_kargs_mul_mv>>) mul_mv_disp_fn_t;

template<mul_mv_disp_fn_t disp_fn>
kernel void kernel_mul_mv_id(
        constant ggml_metal_kargs_mul_mv_id & args,
        device const char * src0s,
        device const char * src1,
        device       char * dst,
        device const char * ids,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    const int iid1 = tgpig.z/args.nei0;
    const int idx  = tgpig.z%args.nei0;

    tgpig.z = 0;

    const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx];

    const int64_t i11 = idx % args.ne11;
    const int64_t i12 = iid1;

    const int64_t i1 = idx;
    const int64_t i2 = i12;

    device const char * src0_cur = src0s + i02*args.nb02;
    device const char * src1_cur = src1  + i11*args.nb11 + i12*args.nb12;

    device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float);
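
    // each (expert, token) pair encoded in tgpig.z is reduced to a plain
    // single-row mat-vec: i02 picks the expert matrix out of src0s and
    // (i11, i12) pick the src1 row; the remapped args below then force
    // ne11/ne12/ne1 and the broadcast ratios r2/r3 to 1 so that the
    // underlying kernel sees an ordinary 2D problem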
    ggml_metal_kargs_mul_mv args0 = {
        /*.ne00 =*/ args.ne00,
        /*.ne01 =*/ args.ne01,
        /*.ne02 =*/ 1, // args.ne02,
        /*.nb00 =*/ args.nb00,
        /*.nb01 =*/ args.nb01,
        /*.nb02 =*/ args.nb02,
        /*.nb03 =*/ args.nb02, // args.ne02 == 1
        /*.ne10 =*/ args.ne10,
        /*.ne11 =*/ 1, // args.ne11,
        /*.ne12 =*/ 1, // args.ne12,
        /*.nb10 =*/ args.nb10,
        /*.nb11 =*/ args.nb11,
        /*.nb12 =*/ args.nb12,
        /*.nb13 =*/ args.nb12, // ne12 == 1
        /*.ne0  =*/ args.ne0,
        /*.ne1  =*/ 1, // args.ne1,
        /*.nr0  =*/ args.nr0,
        /*.r2   =*/ 1,
        /*.r3   =*/ 1,
    };

    disp_fn(
        args0,
        /* src0 */ src0_cur,
        /* src1 */ src1_cur,
        /* dst  */ dst_cur,
        shmem,
        tgpig,
        tiitg,
        tiisg,
        sgitg);
}
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>) kernel_mul_mv_id_t;
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>) kernel_mul_mv_id_4_t;

template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>;
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<half, float>>>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<bfloat, float>>>;
#endif
template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>;
template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<half, half4, float, float4>>>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<bfloat, bfloat4, float, float4>>>;
#endif
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1>>>;
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4>>>;
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K>>>;
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K>>>;
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K>>>;
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K>>>;
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K>>>;
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S>>>;
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M>>>;
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS>>>;
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS>>>;
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS>>>;
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S>>>;
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S>>>;
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL>>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS>>>;
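
//
// pooling
//

// one thread per output element: gid decomposes into a flattened
// batch*channel index (nc) and the output coordinates (cur_oh, cur_ow), and
// the pooling window is clamped to the input bounds. A worked example
// (illustrative values, not from the source): with IW == 5, k0 == 3, s0 == 2,
// p0 == 1, the first output column has start_w == -1, so the window is
// clamped to [0, 2)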
kernel void kernel_pool_2d_max_f32(
        constant ggml_metal_kargs_pool_2d & args,
        device const float * src0,
        device       float * dst,
        uint gid[[thread_position_in_grid]]) {
    if (gid >= args.np) {
        return;
    }

    const int idx  = gid;
    const int I_HW = args.IH * args.IW;
    const int O_HW = args.OH * args.OW;

    const int nc     = idx / O_HW;
    const int cur_oh = idx % O_HW / args.OW;
    const int cur_ow = idx % O_HW % args.OW;

    device const float * i_ptr = src0 + nc * I_HW;
    device       float * o_ptr = dst  + nc * O_HW;

    const int start_h = cur_oh * args.s1 - args.p1;
    const int bh      = MAX(0, start_h);
    const int eh      = MIN(args.IH, start_h + args.k1);

    const int start_w = cur_ow * args.s0 - args.p0;
    const int bw      = MAX(0, start_w);
    const int ew      = MIN(args.IW, start_w + args.k0);

    float res = -INFINITY;

    for (int i = bh; i < eh; i += 1) {
        for (int j = bw; j < ew; j += 1) {
            res = MAX(res, i_ptr[i * args.IW + j]);
        }
    }

    o_ptr[cur_oh * args.OW + cur_ow] = res;
}
kernel void kernel_pool_2d_avg_f32(
        constant ggml_metal_kargs_pool_2d & args,
        device const float * src0,
        device       float * dst,
        uint gid[[thread_position_in_grid]]) {
    if (gid >= args.np) {
        return;
    }

    const int idx  = gid;
    const int I_HW = args.IH * args.IW;
    const int O_HW = args.OH * args.OW;

    const int nc     = idx / O_HW;
    const int cur_oh = idx % O_HW / args.OW;
    const int cur_ow = idx % O_HW % args.OW;

    device const float * i_ptr = src0 + nc * I_HW;
    device       float * o_ptr = dst  + nc * O_HW;

    const int start_h = cur_oh * args.s1 - args.p1;
    const int bh      = MAX(0, start_h);
    const int eh      = MIN(args.IH, start_h + args.k1);

    const int start_w = cur_ow * args.s0 - args.p0;
    const int bw      = MAX(0, start_w);
    const int ew      = MIN(args.IW, start_w + args.k0);
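
    // the divisor is the full kernel area, i.e. padded positions count toward
    // the average (count_include_pad-style); the commented-out variant below
    // would instead average only over the clamped window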
    // const float scale = 1. / ((eh - bh) * (ew - bw));
    const float scale = 1. / (args.k0 * args.k1);

    float res = 0;

    for (int i = bh; i < eh; i += 1) {
        for (int j = bw; j < ew; j += 1) {
            float cur = i_ptr[i * args.IW + j];
            res += cur * scale;
        }
    }

    o_ptr[cur_oh * args.OW + cur_ow] = res;
}
kernel void kernel_opt_step_adamw_f32(
        constant ggml_metal_kargs_opt_step_adamw & args,
        device       float * x,
        device const float * g,
        device       float * g_m,
        device       float * g_v,
        device const float * pars,
        uint gid[[thread_position_in_grid]]) {
    if (gid >= args.np) {
        return;
    }

    const float alpha  = pars[0];
    const float beta1  = pars[1];
    const float beta2  = pars[2];
    const float eps    = pars[3];
    const float wd     = pars[4];
    const float beta1h = pars[5];
    const float beta2h = pars[6];
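
    // the AdamW update, as implemented below (beta1h/beta2h are presumably the
    // host-precomputed bias corrections 1/(1 - beta1^t) and 1/(1 - beta2^t)):
    //   m_t = beta1*m_{t-1} + (1 - beta1)*g
    //   v_t = beta2*v_{t-1} + (1 - beta2)*g^2
    //   x_t = x_{t-1}*(1 - alpha*wd) - alpha*(m_t*beta1h)/(sqrt(v_t*beta2h) + eps)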
    const float gi  = g[gid];
    const float gmi = g_m[gid] * beta1 +      gi * (1.0f - beta1);
    const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);

    g_m[gid] = gmi;
    g_v[gid] = gvi;

    const float mh = gmi * beta1h;
    const float vh = sqrt(gvi * beta2h) + eps;

    x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
}
kernel void kernel_opt_step_sgd_f32(
        constant ggml_metal_kargs_opt_step_sgd & args,
        device       float * x,
        device const float * g,
        device const float * pars,
        uint gid[[thread_position_in_grid]]) {
    if (gid >= args.np) {
        return;
    }
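
    // plain SGD with decoupled weight decay; by analogy with the AdamW kernel
    // above, pars[0] is presumably the learning rate and pars[1] the weight
    // decay, giving: x_t = x_{t-1}*(1 - lr*wd) - lr*g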
    x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}