- #define GGML_COMMON_DECL_METAL
- #define GGML_COMMON_IMPL_METAL
- #if defined(GGML_METAL_EMBED_LIBRARY)
- __embed_ggml-common.h__
- #else
- #include "ggml-common.h"
- #endif
- #include "ggml-metal-impl.h"
- #include <metal_stdlib>
- #ifdef GGML_METAL_HAS_TENSOR
- #include <metal_tensor>
- #include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
- #endif
- using namespace metal;
- #define MAX(x, y) ((x) > (y) ? (x) : (y))
- #define MIN(x, y) ((x) < (y) ? (x) : (y))
- #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
- #define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))
- #define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)
- #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
- // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
- //
- // cmd:
- // .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/ggml-metal.metal
- // .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/ggml-metal.metal
- //
- #if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16)
- #undef GGML_METAL_HAS_BF16
- #endif
- #if defined(GGML_METAL_HAS_BF16)
- typedef matrix<bfloat, 4, 4> bfloat4x4;
- typedef matrix<bfloat, 2, 4> bfloat2x4;
- #endif
- constexpr constant static float kvalues_iq4nl_f[16] = {
- -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
- };
- constexpr constant static float kvalues_mxfp4_f[16] = {
- 0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
- };
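- // best_index_int8: nearest-codebook-entry lookup, assuming `val` is sorted
- // ascending. A binary search narrows [ml, mu] down to two neighboring entries,
- // then the closer of val[mu-1] and val[mu] to x is returned.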
- static inline int best_index_int8(int n, constant float * val, float x) {
- if (x <= val[0]) return 0;
- if (x >= val[n-1]) return n-1;
- int ml = 0, mu = n-1;
- while (mu-ml > 1) {
- int mav = (ml+mu)/2;
- if (x < val[mav]) mu = mav; else ml = mav;
- }
- return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
- }
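- // e8m0_to_fp32: decode an E8M0 scale (8-bit biased exponent, no mantissa),
- // i.e. x -> 2^(x - 127). Shifting x into the float exponent field yields
- // exactly that power of two; x == 0 is special-cased to the subnormal
- // 0x00400000 = 2^-127, since a zero exponent field would otherwise decode to 0.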
- static inline float e8m0_to_fp32(uint8_t x) {
- uint32_t bits;
- if (x == 0) {
- bits = 0x00400000;
- } else {
- bits = (uint32_t) x << 23;
- }
- return as_type<float>(bits);
- }
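- // scalar overload of dot() so templated code can also reduce plain floats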
- static inline float dot(float x, float y) {
- return x*y;
- }
- // NOTE: this is not dequantizing - we are simply fitting the template
- template <typename type4x4>
- void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
- reg = (type4x4)(*src);
- }
- template <typename type4>
- void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
- reg = (type4)(*src);
- }
- template <typename type4x4>
- void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
- reg = (type4x4)(*src);
- }
- template <typename type4>
- void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
- reg = (type4)(*(src));
- }
- #if defined(GGML_METAL_HAS_BF16)
- template <typename type4x4>
- void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
- reg = (type4x4)(*src);
- }
- template <typename type4>
- void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
- reg = (type4)(*(src));
- }
- #endif
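- // q4_0 blocks store a half scale d followed by 16 bytes of nibbles, read here
- // as 8 uint16_t's. mask0 selects a nibble from the low byte and mask1 the same
- // nibble from the high byte; instead of shifting the high byte down by 8, the
- // scale is pre-divided (d2 = d1/256), and likewise d1 = d/16 folds in the >>4
- // shift when il selects the high nibbles. md = -8*d recenters [0,15] to [-8,7].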
- template <typename type4x4>
- void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 1);
- const float d1 = il ? (xb->d / 16.h) : xb->d;
- const float d2 = d1 / 256.f;
- const float md = -8.h * xb->d;
- const ushort mask0 = il ? 0x00F0 : 0x000F;
- const ushort mask1 = mask0 << 8;
- float4x4 reg_f;
- for (int i = 0; i < 8; i++) {
- reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
- reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
- }
- reg = (type4x4) reg_f;
- }
- template <typename type4>
- void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 1);
- const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
- const float d2 = d1 / 256.f;
- const float md = -8.h * xb->d;
- const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
- const ushort mask1 = mask0 << 8;
- for (int i = 0; i < 2; i++) {
- reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
- reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
- }
- }
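- // quantize_q4_0: symmetric 4-bit quantization. The scale is anchored to the
- // value with the largest magnitude (d = max/-8), so that value maps exactly to
- // code 0 (i.e. -8 after the bias). Adding 8.5f applies the +8 bias plus
- // round-to-nearest before the truncating cast; MIN(15, ...) clamps the top end.
- // The first half of the block goes in the low nibbles, the second in the high.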
- void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
- #pragma METAL fp math_mode(safe)
- float amax = 0.0f; // absolute max
- float max = 0.0f;
- for (int j = 0; j < QK4_0; j++) {
- const float v = src[j];
- if (amax < fabs(v)) {
- amax = fabs(v);
- max = v;
- }
- }
- const float d = max / -8;
- const float id = d ? 1.0f/d : 0.0f;
- dst.d = d;
- for (int j = 0; j < QK4_0/2; ++j) {
- const float x0 = src[0 + j]*id;
- const float x1 = src[QK4_0/2 + j]*id;
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
- dst.qs[j] = xi0;
- dst.qs[j] |= xi1 << 4;
- }
- }
- void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
- #pragma METAL fp math_mode(safe)
- float min = FLT_MAX;
- float max = -FLT_MAX;
- for (int j = 0; j < QK4_1; j++) {
- const float v = src[j];
- if (min > v) min = v;
- if (max < v) max = v;
- }
- const float d = (max - min) / ((1 << 4) - 1);
- const float id = d ? 1.0f/d : 0.0f;
- dst.d = d;
- dst.m = min;
- for (int j = 0; j < QK4_1/2; ++j) {
- const float x0 = (src[0 + j] - min)*id;
- const float x1 = (src[QK4_1/2 + j] - min)*id;
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
- dst.qs[j] = xi0;
- dst.qs[j] |= xi1 << 4;
- }
- }
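- // quantize_q5_0: same scheme with 5-bit codes (d = max/-16). The low 4 bits of
- // each code are packed into the qs nibbles, and the 5th bits of all 32 codes
- // are gathered into the 32-bit qh bitfield.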
- void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
- #pragma METAL fp math_mode(safe)
- float amax = 0.0f; // absolute max
- float max = 0.0f;
- for (int j = 0; j < QK5_0; j++) {
- const float v = src[j];
- if (amax < fabs(v)) {
- amax = fabs(v);
- max = v;
- }
- }
- const float d = max / -16;
- const float id = d ? 1.0f/d : 0.0f;
- dst.d = d;
- uint32_t qh = 0;
- for (int j = 0; j < QK5_0/2; ++j) {
- const float x0 = src[0 + j]*id;
- const float x1 = src[QK5_0/2 + j]*id;
- const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
- const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
- dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
- }
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
- for (int j = 0; j < 4; ++j) {
- dst.qh[j] = qh8[j];
- }
- }
- void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
- #pragma METAL fp math_mode(safe)
- float max = src[0];
- float min = src[0];
- for (int j = 1; j < QK5_1; j++) {
- const float v = src[j];
- min = v < min ? v : min;
- max = v > max ? v : max;
- }
- const float d = (max - min) / 31;
- const float id = d ? 1.0f/d : 0.0f;
- dst.d = d;
- dst.m = min;
- uint32_t qh = 0;
- for (int j = 0; j < QK5_1/2; ++j) {
- const float x0 = (src[0 + j] - min)*id;
- const float x1 = (src[QK5_1/2 + j] - min)*id;
- const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
- const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
- dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
- }
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
- for (int j = 0; j < 4; ++j) {
- dst.qh[j] = qh8[j];
- }
- }
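- // quantize_q8_0: symmetric 8-bit quantization with d = amax/127 and
- // round-to-nearest codes in [-127, 127].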
- void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
- #pragma METAL fp math_mode(safe)
- float amax = 0.0f; // absolute max
- for (int j = 0; j < QK8_0; j++) {
- const float v = src[j];
- amax = MAX(amax, fabs(v));
- }
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
- dst.d = d;
- for (int j = 0; j < QK8_0; ++j) {
- const float x0 = src[j]*id;
- dst.qs[j] = round(x0);
- }
- }
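- // quantize_iq4_nl: codes are nearest entries of the non-linear kvalues_iq4nl_f
- // codebook. The initial scale anchors the largest-magnitude value to the first
- // codebook entry (-127); d is then refined by weighted least squares with
- // weights w = x^2: minimizing sum w*(x - d*q)^2 gives d = sum(w*q*x)/sum(w*q^2),
- // which is what sumqx/sumq2 computes below.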
- void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
- #pragma METAL fp math_mode(safe)
- float amax = 0.0f; // absolute max
- float max = 0.0f;
- for (int j = 0; j < QK4_NL; j++) {
- const float v = src[j];
- if (amax < fabs(v)) {
- amax = fabs(v);
- max = v;
- }
- }
- const float d = max / kvalues_iq4nl_f[0];
- const float id = d ? 1.0f/d : 0.0f;
- float sumqx = 0, sumq2 = 0;
- for (int j = 0; j < QK4_NL/2; ++j) {
- const float x0 = src[0 + j]*id;
- const float x1 = src[QK4_NL/2 + j]*id;
- const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
- const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
- dst.qs[j] = xi0 | (xi1 << 4);
- const float v0 = kvalues_iq4nl_f[xi0];
- const float v1 = kvalues_iq4nl_f[xi1];
- const float w0 = src[0 + j]*src[0 + j];
- const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
- sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
- sumq2 += w0*v0*v0 + w1*v1*v1;
- }
- dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
- }
- template <typename type4x4>
- void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 2);
- const float d1 = il ? (xb->d / 16.h) : xb->d;
- const float d2 = d1 / 256.f;
- const float m = xb->m;
- const ushort mask0 = il ? 0x00F0 : 0x000F;
- const ushort mask1 = mask0 << 8;
- float4x4 reg_f;
- for (int i = 0; i < 8; i++) {
- reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
- reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
- }
- reg = (type4x4) reg_f;
- }
- template <typename type4>
- void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 2);
- const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
- const float d2 = d1 / 256.f;
- const float m = xb->m;
- const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
- const ushort mask1 = mask0 << 8;
- for (int i = 0; i < 2; i++) {
- reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
- reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
- }
- }
- template <typename type4x4>
- void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 3);
- const float d = xb->d;
- const float md = -16.h * xb->d;
- const ushort mask = il ? 0x00F0 : 0x000F;
- const uint32_t qh = *((device const uint32_t *)xb->qh);
- const int x_mv = il ? 4 : 0;
- const int gh_mv = il ? 12 : 0;
- const int gh_bk = il ? 0 : 4;
- float4x4 reg_f;
- for (int i = 0; i < 8; i++) {
- // extract the 5th bit for x0 and x1
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
- // combine the low 4 bits from qs with the 5th bit
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
- reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
- reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
- }
- reg = (type4x4) reg_f;
- }
- template <typename type4>
- void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 3);
- const float d = xb->d;
- const float md = -16.h * xb->d;
- const ushort mask = (il/4) ? 0x00F0 : 0x000F;
- const uint32_t qh = *((device const uint32_t *)xb->qh);
- const int x_mv = (il/4) ? 4 : 0;
- const int gh_mv = (il/4) ? 12 : 0;
- const int gh_bk = (il/4) ? 0 : 4;
- for (int ii = 0; ii < 2; ii++) {
- int i = 2*(il%4) + ii;
- // extract the 5th bit for x0 and x1
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
- // combine the low 4 bits from qs with the 5th bit
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
- reg[2*ii + 0] = d * x0 + md;
- reg[2*ii + 1] = d * x1 + md;
- }
- }
- template <typename type4x4>
- void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 4);
- const float d = xb->d;
- const float m = xb->m;
- const ushort mask = il ? 0x00F0 : 0x000F;
- const uint32_t qh = *((device const uint32_t *)xb->qh);
- const int x_mv = il ? 4 : 0;
- const int gh_mv = il ? 12 : 0;
- const int gh_bk = il ? 0 : 4;
- float4x4 reg_f;
- for (int i = 0; i < 8; i++) {
- // extract the 5th bit for x0 and x1
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
- // combine the low 4 bits from qs with the 5th bit
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
- reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
- reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
- }
- reg = (type4x4) reg_f;
- }
- template <typename type4>
- void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 4);
- const float d = xb->d;
- const float m = xb->m;
- const ushort mask = (il/4) ? 0x00F0 : 0x000F;
- const uint32_t qh = *((device const uint32_t *)xb->qh);
- const int x_mv = (il/4) ? 4 : 0;
- const int gh_mv = (il/4) ? 12 : 0;
- const int gh_bk = (il/4) ? 0 : 4;
- for (int ii = 0; ii < 2; ii++) {
- int i = 2*(il%4) + ii;
- // extract the 5th bit for x0 and x1
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
- // combine the low 4 bits from qs with the 5th bit
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
- reg[2*ii + 0] = d * x0 + m;
- reg[2*ii + 1] = d * x1 + m;
- }
- }
- template <typename type4x4>
- void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
- device const int8_t * qs = ((device const int8_t *)xb->qs);
- const float d = xb->d;
- float4x4 reg_f;
- for (int i = 0; i < 16; i++) {
- reg_f[i/4][i%4] = (qs[i + 16*il] * d);
- }
- reg = (type4x4) reg_f;
- }
- template <typename type4>
- void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
- device const int8_t * qs = ((device const int8_t *)xb->qs);
- const float d = xb->d;
- for (int i = 0; i < 4; i++) {
- reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
- }
- }
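- // mxfp4 blocks hold a shared E8M0 scale byte e and 16 bytes of FP4 (E2M1)
- // codes, decoded via the kvalues_mxfp4_f table; the low nibbles hold the first
- // 16 values of the 32-value block and the high nibbles the second 16.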
- template <typename type4x4>
- void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
- device const uint8_t * q2 = (device const uint8_t *)xb->qs;
- const float d = e8m0_to_fp32(xb->e);
- const uint8_t shr = il >= 1 ? 4 : 0;
- for (int i = 0; i < 4; ++i) {
- reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
- reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
- reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
- reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
- }
- }
- template <typename type4>
- void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
- device const uint8_t * q2 = (device const uint8_t *)xb->qs;
- const float d = e8m0_to_fp32(xb->e);
- const short il4 = il%4;
- const uint8_t shr = il >= 4 ? 4 : 0;
- reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
- reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
- reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
- reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
- }
- template <typename type4x4>
- void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
- const float d = xb->d;
- const float min = xb->dmin;
- device const uint8_t * q = (device const uint8_t *)xb->qs;
- float dl, ml;
- uint8_t sc = xb->scales[il];
- q = q + 32*(il/8) + 16*(il&1);
- il = (il/2)%4;
- half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
- uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
- dl = d * (sc & 0xF) * coef;
- ml = min * (sc >> 4);
- for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = dl * (q[i] & mask) - ml;
- }
- }
- template <typename type4x4>
- void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
- const half d_all = xb->d;
- device const uint8_t * q = (device const uint8_t *)xb->qs;
- device const uint8_t * h = (device const uint8_t *)xb->hmask;
- device const int8_t * scales = (device const int8_t *)xb->scales;
- q = q + 32 * (il/8) + 16 * (il&1);
- h = h + 16 * (il&1);
- uint8_t m = 1 << (il/2);
- uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : ((il/4)>0 ? 12 : 3);
- uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
- uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
- int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
- : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
- float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
- const float ml = 4.f * dl;
- il = (il/2) & 3;
- const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
- const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
- dl *= coef;
- for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
- }
- }
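- // get_scale_min_k4_just2: unpack the (j,k)-th 6-bit (scale, min) pair from the
- // 12-byte K-quant scales array. The first four pairs sit in the low 6 bits of
- // bytes 0-3 (scales) and 4-7 (mins); the rest are split across the nibbles of
- // bytes 8-11 plus the top 2 bits of bytes 0-7.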
- static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
- return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
- : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
- }
- template <typename type4x4>
- void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
- device const uchar * q = xb->qs;
- short is = (il/4) * 2;
- q = q + (il/4) * 32 + 16 * (il&1);
- il = il & 3;
- const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
- const float d = il < 2 ? xb->d : xb->d / 16.h;
- const float min = xb->dmin;
- const float dl = d * sc[0];
- const float ml = min * sc[1];
- const ushort mask = il < 2 ? 0x0F : 0xF0;
- for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = dl * (q[i] & mask) - ml;
- }
- }
- template <typename type4x4>
- void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
- device const uint8_t * q = xb->qs;
- device const uint8_t * qh = xb->qh;
- short is = (il/4) * 2;
- q = q + 32 * (il/4) + 16 * (il&1);
- qh = qh + 16 * (il&1);
- uint8_t ul = 1 << (il/2);
- il = il & 3;
- const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
- const float d = il < 2 ? xb->d : xb->d / 16.f;
- const float min = xb->dmin;
- const float dl = d * sc[0];
- const float ml = min * sc[1];
- const ushort mask = il<2 ? 0x0F : 0xF0;
- const float qh_val = il<2 ? 16.f : 256.f;
- for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
- }
- }
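- // dequantize_q6_K assembles four 6-bit codes at a time into the byte lanes of a
- // 32-bit word: `low` carries the lower 4 bits, `high` the upper 2. Each lane is
- // then scaled in place, with dl1 = dl0/256, dl2 = dl0/256^2, dl3 = dl0/256^3
- // folding the lane shift into the scale; ml = 32*d*sc recenters [0,63] to [-32,31].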
- template <typename type4x4>
- void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
- const half d_all = xb->d;
- device const uint16_t * ql = (device const uint16_t *)xb->ql;
- device const uint16_t * qh = (device const uint16_t *)xb->qh;
- device const int8_t * scales = (device const int8_t *)xb->scales;
- ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
- qh = qh + 16*(il/8) + 8*(il&1);
- float sc = scales[(il%2) + 2*(il/2)];
- il = (il/2) & 3;
- const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
- const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;
- const float ml = d_all * sc * 32.f;
- const float dl0 = d_all * sc;
- const float dl1 = dl0 / 256.f;
- const float dl2 = dl0 / (256.f * 256.f);
- const float dl3 = dl0 / (256.f * 256.f * 256.f);
- const uint8_t shr_h = il>2 ? 2 : 0;
- const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
- const uint8_t shr_l = il>1 ? 4 : 0;
- for (int i = 0; i < 4; ++i) {
- const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
- const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
- const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
- reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
- reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
- reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
- reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
- }
- }
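
// Byte-lane trick in the loop above: q packs four quant bytes into one 32-bit
// word; rather than extracting each byte with a shift, the code masks the byte
// in place and divides the scale by 256^k (dl1..dl3). Illustrative sketch of
// the identity (not used by any kernel; the name is ours):
static inline float q6_byte_lane_ref(uint32_t q, short k, float dl0, float ml) {
    // explicit form: dl0 * byte_k - ml  ==  (dl0/256^k) * (q & (0xFF << 8k)) - ml
    return dl0 * (float)((q >> (8*k)) & 0xFF) - ml;
}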
template <typename type4x4>
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
    device const uint16_t * q2 = xb->qs + 4*ib32;
    const uint32_t aux32_g = q2[0] | (q2[1] << 16);
    const uint32_t aux32_s = q2[2] | (q2[3] << 16);
    thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;

    const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
    constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
    uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
    for (int i = 0; i < 8; ++i) {
        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
    }

    grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
    signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
    for (int i = 0; i < 8; ++i) {
        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
    }
}
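
// In the i-quant decoders above and below, each group of 8 weights is a row of
// a fixed codebook (e.g. iq2xxs_grid) selected by one index byte, while a
// 7-bit index into ksigns_iq2xs supplies the per-weight sign pattern.
// Illustrative single-lane sketch (not used by any kernel; the name is ours):
static inline float iq2_signed_lane_ref(float dl, uint8_t g, uint8_t signs, short i) {
    // bit i of the sign mask flips the magnitude taken from the grid row
    return dl * g * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}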
template <typename type4x4>
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint16_t * q2 = xb->qs + 4*ib32;
    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;

    constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
    uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
    for (int i = 0; i < 8; ++i) {
        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
    }

    grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
    signs = ksigns_iq2xs[q2[2*il+1] >> 9];
    for (int i = 0; i < 8; ++i) {
        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
    }
}

template <typename type4x4>
void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint8_t  * q3  = xb->qs + 8*ib32;
    device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
    const uint32_t aux32 = gas[0] | (gas[1] << 16);
    const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;

    constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
    constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
    uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
    for (int i = 0; i < 4; ++i) {
        reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
        reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
    }

    grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
    grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
    signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
    for (int i = 0; i < 4; ++i) {
        reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
        reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
    }
}

template <typename type4x4>
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint8_t * qs = xb->qs + 8*ib32;
    device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
    const uint8_t qh = xb->qh[ib32] >> 4*il;
    const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));

    constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
    constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
    for (int i = 0; i < 4; ++i) {
        reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
        reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
    }

    grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
    grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
    for (int i = 0; i < 4; ++i) {
        reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
        reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
    }
}

template <typename type4x4>
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
    device const uint8_t * signs = qs + QK_K/8;
    const uint8_t qh = xb->qh[ib32] >> 4*il;
    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;

    constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
    constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
    for (int i = 0; i < 8; ++i) {
        reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
        reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
    }
}

template <typename type4x4>
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const int ib32 = il/2;
    il = il%2;
    const float d = xb->d;
    device const uint8_t  * qs = xb->qs + 4*ib32 + 2*il;
    device const uint16_t * qh = xb->qh;

    const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
    const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
    const uint16_t h = qh[ib32] >> 6*il;
    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
    for (int i = 0; i < 4; ++i) {
        reg[0][i] = dl * (grid1[i] & 0xf) + ml;
        reg[1][i] = dl * (grid1[i] >> 4)  + ml;
        reg[2][i] = dl * (grid2[i] & 0xf) + ml;
        reg[3][i] = dl * (grid2[i] >> 4)  + ml;
    }
}
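
// Reading dequantize_iq1_s above: the 3-bit field (qh >> 12) & 7 yields the odd
// per-block scales dl = d*{1,3,...,15}, and ml = dl*(-1 +- IQ1S_DELTA)
// recenters the 4-bit grid entries so the decoded values land close to
// dl*{-1, 0, +1}, nudged up or down by the 0x8000 sign bit. A worked example
// (our numbers, purely for illustration): with (qh >> 12) & 7 == 3 and the
// 0x8000 bit clear, dl = 7*d and ml = 7*d*(IQ1S_DELTA - 1).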
template <typename type4x4>
void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const int ib32 = il/2;
    il = il%2;
    device const uint16_t * sc = (device const uint16_t *)xb->scales;

    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    const float d = scale.f16;

    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
    device const uint8_t * qh = xb->qh + 2*ib32 + il;

    const float dl  = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
    const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
    const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
    for (int i = 0; i < 4; ++i) {
        reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
        reg[1][i] = dl * (grid1[i] >> 4)  + ml1;
        reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
        reg[3][i] = dl * (grid2[i] >> 4)  + ml2;
    }
}

template <typename type4x4>
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
    const float d = xb->d;
    uint32_t aux32;
    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
    for (int i = 0; i < 4; ++i) {
        aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
    }
}

template <typename type4>
void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
    const float d = xb->d;
    uint32_t aux32;
    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
    aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
    reg[0] = d * kvalues_iq4nl_f[q8[0]];
    reg[1] = d * kvalues_iq4nl_f[q8[1]];
    reg[2] = d * kvalues_iq4nl_f[q8[2]];
    reg[3] = d * kvalues_iq4nl_f[q8[3]];
}

template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
    const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
    const float d = (float)xb->d * (ls - 32);
    uint32_t aux32;
    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
    for (int i = 0; i < 4; ++i) {
        aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
    }
}
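
// The IQ4 decoders above replace a linear scale with a 16-entry non-linear
// lookup table: each 4-bit code indexes kvalues_iq4nl_f and only the result is
// scaled by d. Illustrative single-value sketch (not used by any kernel; the
// name is ours):
static inline float iq4_lut_ref(uint8_t code, float d) {
    // a 4-bit code selects a tabulated value instead of being scaled directly
    return d * kvalues_iq4nl_f[code & 0x0F];
}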
enum ggml_sort_order {
    GGML_SORT_ORDER_ASC,
    GGML_SORT_ORDER_DESC,
};

// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
// pros: works for non-contiguous tensors, supports broadcast across all dims
// cons: not very efficient
template <int F>
kernel void kernel_add_fuse_impl(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig.z;
    const int i02 = tgpig.y;
    const int i01 = tgpig.x;

    const int i13 = i03%args.ne13;
    const int i12 = i02%args.ne12;
    const int i11 = i01%args.ne11;

    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
    device       float * dst_ptr  = (device       float *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);

    device const float * src1_ptr[F];
    for (short j = 0; j < F; ++j) {
        src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
    }

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        const int i10 = i0%args.ne10;

        float res = src0_ptr[i0];

#pragma unroll
        for (short j = 0; j < F; ++j) {
            res += src1_ptr[j][i10];
        }

        dst_ptr[i0] = res;
    }
}

typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;

template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
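
// Why F is a compile-time template parameter: the src1_ptr[F] array can live
// entirely in registers and the inner accumulation loop fully unrolls, so a
// chain of up to 8 adds is fused into a single pass over dst with no
// intermediate tensors. The host side presumably dispatches the
// kernel_add_fuse_<F> variant whose F matches the number of fused operands
// recorded in args.o1 (our reading of the scheme, not a statement about the
// host code).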
kernel void kernel_sub_fuse_1(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig.z;
    const int i02 = tgpig.y;
    const int i01 = tgpig.x;

    const int i13 = i03%args.ne13;
    const int i12 = i02%args.ne12;
    const int i11 = i01%args.ne11;

    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        const int i10 = i0%args.ne10;
        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
    }
}

kernel void kernel_mul_fuse_1(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig.z;
    const int i02 = tgpig.y;
    const int i01 = tgpig.x;

    const int i13 = i03%args.ne13;
    const int i12 = i02%args.ne12;
    const int i11 = i01%args.ne11;

    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;

    if (args.ne10 == 1) {
        const float x = *((device float *)(src1_ptr));
        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
        }
    } else {
        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
            const int i10 = i0%args.ne10;
            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
        }
    }
}

kernel void kernel_div_fuse_1(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig.z;
    const int i02 = tgpig.y;
    const int i01 = tgpig.x;

    const int i13 = i03%args.ne13;
    const int i12 = i02%args.ne12;
    const int i11 = i01%args.ne11;

    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;

    if (args.ne10 == 1) {
        const float x = 1.0f / *((device float *)(src1_ptr));
        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
        }
    } else {
        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
            const int i10 = i0%args.ne10;
            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
        }
    }
}

kernel void kernel_add_id(
        constant ggml_metal_kargs_add_id & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i1 = tgpig.x;
    const int i2 = tgpig.y;

    const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));

    const size_t nb1 = args.ne0 * sizeof(float);
    const size_t nb2 = args.ne1 * nb1;

    device       float * dst_row  = (device       float *)((device char *)dst + i1*nb1 + i2*nb2);
    device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
    device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        dst_row[i0] = src0_row[i0] + src1_row[i0];
    }
}

template<typename T>
kernel void kernel_repeat(
        constant ggml_metal_kargs_repeat & args,
        device const char * src0,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    const int i03 = i3%args.ne03;
    const int i02 = i2%args.ne02;
    const int i01 = i1%args.ne01;

    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
    device       char * dst_ptr  = dst  + i3*args.nb3   + i2*args.nb2   + i1*args.nb1;

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        const int i00 = i0%args.ne00;
        *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
    }
}

typedef decltype(kernel_repeat<float>) kernel_repeat_t;

template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;

// assumption: src1 is a row
// broadcast src1 into src0
template <short F>
kernel void kernel_add_row_c4_fuse_impl(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tpig[[thread_position_in_grid]]) {
    const uint nb = args.ne00/4;
    const uint i  = tpig % nb;

    device const float4 * src0_row = (device const float4 *) (src0);
    device       float4 * dst_row  = (device       float4 *) (dst);

    float4 res = src0_row[tpig];

#pragma unroll(F)
    for (short j = 0; j < F; ++j) {
        res += ((device const float4 *) (src1 + args.o1[j]))[i];
    }

    dst_row[tpig] = res;
}

typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;

template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;

template <short F>
kernel void kernel_sub_row_c4_fuse_impl(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tpig[[thread_position_in_grid]]) {
    const uint nb = args.ne00/4;
    const uint i  = tpig % nb;

    device const float4 * src0_row = (device const float4 *) (src0);
    device       float4 * dst_row  = (device       float4 *) (dst);

    device const float4 * src1_row[F];
    for (short j = 0; j < F; ++j) {
        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
    }

    float4 res = src0_row[tpig];

#pragma unroll(F)
    for (short j = 0; j < F; ++j) {
        res -= src1_row[j][i];
    }

    dst_row[tpig] = res;
}

typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;

template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;

template <short F>
kernel void kernel_mul_row_c4_fuse_impl(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tpig[[thread_position_in_grid]]) {
    const uint nb = args.ne00/4;
    const uint i  = tpig % nb;

    device const float4 * src0_row = (device const float4 *) (src0);
    device       float4 * dst_row  = (device       float4 *) (dst);

    device const float4 * src1_row[F];
    for (short j = 0; j < F; ++j) {
        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
    }

    float4 res = src0_row[tpig];

#pragma unroll(F)
    for (short j = 0; j < F; ++j) {
        res *= src1_row[j][i];
    }

    dst_row[tpig] = res;
}

typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;

template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;

template <short F>
kernel void kernel_div_row_c4_fuse_impl(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tpig[[thread_position_in_grid]]) {
    const uint nb = args.ne00/4;
    const uint i  = tpig % nb;

    device const float4 * src0_row = (device const float4 *) (src0);
    device       float4 * dst_row  = (device       float4 *) (dst);

    device const float4 * src1_row[F];
    for (short j = 0; j < F; ++j) {
        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
    }

    float4 res = src0_row[tpig];

#pragma unroll(F)
    for (short j = 0; j < F; ++j) {
        res /= src1_row[j][i];
    }

    dst_row[tpig] = res;
}

typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;

template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;

kernel void kernel_scale_f32(
        constant ggml_metal_kargs_scale & args,
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * args.scale + args.bias;
}

kernel void kernel_scale_f32_4(
        constant ggml_metal_kargs_scale & args,
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * args.scale + args.bias;
}

kernel void kernel_clamp_f32(
        constant ggml_metal_kargs_clamp & args,
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = clamp(src0[tpig], args.min, args.max);
}

kernel void kernel_clamp_f32_4(
        constant ggml_metal_kargs_clamp & args,
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = clamp(src0[tpig], args.min, args.max);
}

kernel void kernel_relu_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = max(0.0f, src0[tpig]);
}

kernel void kernel_relu_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = max(0.0f, src0[tpig]);
}

kernel void kernel_sigmoid_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
}

kernel void kernel_sigmoid_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
}

kernel void kernel_tanh_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = precise::tanh(src0[tpig]);
}

kernel void kernel_tanh_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = precise::tanh(src0[tpig]);
}

constant float GELU_COEF_A     = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
constant float SQRT_2_INV      = 0.70710678118654752440084436210484f;
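
// For reference: SQRT_2_OVER_PI is sqrt(2/pi) and SQRT_2_INV is 1/sqrt(2).
// The kernels below use the common tanh approximation of GELU,
//   gelu(x) ~= 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))),
// while the "quick" variant uses the sigmoid form x*sigmoid(1.702*x);
// GELU_QUICK_COEF stores the negated 1.702 so it can feed exp() directly.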
kernel void kernel_gelu_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float & x = src0[tpig];

    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

kernel void kernel_gelu_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float4 & x = src0[tpig];

    // BEWARE !!!
    // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
    // This was observed with the Falcon 7B and 40B models
    //
    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

kernel void kernel_gelu_quick_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float & x = src0[tpig];

    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}

kernel void kernel_gelu_quick_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float4 & x = src0[tpig];

    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}

// based on Abramowitz and Stegun formula 7.1.26 (a Hastings-type approximation)
// ref: https://www.johndcook.com/blog/python_erf/
constant float p_erf  = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;

template<typename T>
T erf_approx(T x) {
    T sign_x = sign(x);
    x = fabs(x);
    T t = 1.0f / (1.0f + p_erf * x);
    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
    return sign_x * y;
}
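
// erf_approx evaluates the Abramowitz & Stegun 7.1.26 polynomial in Horner
// form; the commonly cited maximum absolute error of this approximation is
// about 1.5e-7, ample for a gelu_erf activation in f32. It extends to negative
// inputs via erf(-x) = -erf(x) (the sign_x factor) and works componentwise for
// float4 thanks to the templated vector math.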
kernel void kernel_gelu_erf_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float & x = src0[tpig];

    dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
}

kernel void kernel_gelu_erf_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float4 & x = src0[tpig];

    dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
}

kernel void kernel_silu_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float & x = src0[tpig];

    dst[tpig] = x / (1.0f + exp(-x));
}

kernel void kernel_silu_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    device const float4 & x = src0[tpig];

    dst[tpig] = x / (1.0f + exp(-x));
}

kernel void kernel_elu_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];

    dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
}

kernel void kernel_elu_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 x = src0[tpig];

    dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
    dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
    dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
    dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
}

kernel void kernel_sqr_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * src0[tpig];
}

kernel void kernel_sqr_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * src0[tpig];
}

kernel void kernel_sqrt_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = sqrt(src0[tpig]);
}

kernel void kernel_sqrt_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = sqrt(src0[tpig]);
}

kernel void kernel_sin_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = sin(src0[tpig]);
}

kernel void kernel_sin_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = sin(src0[tpig]);
}

kernel void kernel_cos_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = cos(src0[tpig]);
}

kernel void kernel_cos_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = cos(src0[tpig]);
}

kernel void kernel_log_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = log(src0[tpig]);
}

kernel void kernel_log_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = log(src0[tpig]);
}

kernel void kernel_neg_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = -src0[tpig];
}

kernel void kernel_neg_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = -src0[tpig];
}

kernel void kernel_abs_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = fabs(src0[tpig]);
}

kernel void kernel_abs_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = fabs(src0[tpig]);
}

kernel void kernel_sgn_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = sign(src0[tpig]);
}

kernel void kernel_sgn_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = sign(src0[tpig]);
}

kernel void kernel_step_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = step(0.0f, src0[tpig]);
}

kernel void kernel_step_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = step(0.0f, src0[tpig]);
}

kernel void kernel_hardswish_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];

    dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}

kernel void kernel_hardswish_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 x = src0[tpig];

    dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}

kernel void kernel_hardsigmoid_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];

    dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}

kernel void kernel_hardsigmoid_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 x = src0[tpig];

    dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}

kernel void kernel_exp_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = exp(src0[tpig]);
}

kernel void kernel_exp_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = exp(src0[tpig]);
}

kernel void kernel_reglu_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
        const float x0 = src0_row[i0];
        const float x1 = src1_row[i0];

        dst_row[i0] = x0*x1*(x0 > 0.0f);
    }
}

kernel void kernel_geglu_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
        const float x0 = src0_row[i0];
        const float x1 = src1_row[i0];

        const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));

        dst_row[i0] = gelu*x1;
    }
}

kernel void kernel_swiglu_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
        const float x0 = src0_row[i0];
        const float x1 = src1_row[i0];

        const float silu = x0 / (1.0f + exp(-x0));

        dst_row[i0] = silu*x1;
    }
}

kernel void kernel_swiglu_oai_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
        float x0 = src0_row[i0];
        float x1 = src1_row[i0];

        x0 = min(x0, args.limit);
        x1 = max(min(x1, args.limit), -args.limit);

        float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
        out_glu = out_glu * (1.0f + x1);

        dst_row[i0] = out_glu;
    }
}

kernel void kernel_geglu_erf_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
        const float x0 = src0_row[i0];
        const float x1 = src1_row[i0];

        const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));

        dst_row[i0] = gelu_erf*x1;
    }
}

kernel void kernel_geglu_quick_f32(
        constant ggml_metal_kargs_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);

    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
        const float x0 = src0_row[i0];
        const float x1 = src1_row[i0];

        const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));

        dst_row[i0] = gelu_quick*x1;
    }
}

kernel void kernel_op_sum_f32(
        constant ggml_metal_kargs_sum & args,
        device const float * src0,
        device       float * dst,
        threadgroup float  * shmem_f32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    if (args.np == 0) {
        return;
    }

    const uint nsg = (ntg.x + 31) / 32;

    float sumf = 0;

    for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
        sumf += src0[i0];
    }

    sumf = simd_sum(sumf);

    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    float total = 0;

    if (sgitg == 0) {
        float v = 0;

        if (tpitg.x < nsg) {
            v = shmem_f32[tpitg.x];
        }

        total = simd_sum(v);

        if (tpitg.x == 0) {
            dst[0] = total;
        }
    }
}

template <bool norm>
kernel void kernel_sum_rows(
        constant ggml_metal_kargs_sum_rows & args,
        device const float * src0,
        device       float * dst,
        threadgroup float  * shmem_f32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    int64_t i3 = tgpig.z;
    int64_t i2 = tgpig.y;
    int64_t i1 = tgpig.x;

    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
        return;
    }

    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }

    device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

    float sumf = 0;

    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        sumf += src_row[i0];
    }

    sumf = simd_sum(sumf);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    sumf = shmem_f32[tiisg];
    sumf = simd_sum(sumf);

    if (tpitg.x == 0) {
        dst_row[0] = norm ? sumf / args.ne00 : sumf;
    }
}

typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;

template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean_f32")]]     kernel kernel_sum_rows_t kernel_sum_rows<true>;
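
// Reduction pattern used by kernel_op_sum_f32 and kernel_sum_rows above: each
// simdgroup reduces its partial sums with simd_sum, lane 0 of every simdgroup
// parks the result in threadgroup memory, and a final simd_sum over those
// slots combines them. With a 32-wide SIMD this covers up to 32 simdgroups,
// i.e. threadgroups of up to 1024 threads.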
template<typename T>
kernel void kernel_cumsum_blk(
        constant ggml_metal_kargs_cumsum_blk & args,
        device const char * src0,
        device       char * tmp,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int ib  = tgpig[0]/args.ne01;
    const int i00 = ib*ntg.x;

    const int i01 = tgpig[0]%args.ne01;
    const int i02 = tgpig[1];
    const int i03 = tgpig[2];

    device const float * src0_row = (device const float *) (src0 +
            args.nb01*i01 +
            args.nb02*i02 +
            args.nb03*i03);

    threadgroup float * shmem_f32 = (threadgroup float *) shmem;

    float v = 0.0f;

    if (i00 + tpitg.x < args.ne00) {
        v = src0_row[i00 + tpitg.x];
    }

    float s = simd_prefix_inclusive_sum(v);

    if (tiisg == N_SIMDWIDTH - 1) {
        shmem_f32[sgitg] = s;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (sgitg == 0) {
        shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    s += shmem_f32[sgitg];

    device float * dst_row = (device float *) dst +
            args.ne00*i01 +
            args.ne00*args.ne01*i02 +
            args.ne00*args.ne01*args.ne02*i03;

    if (i00 + tpitg.x < args.ne00) {
        dst_row[i00 + tpitg.x] = s;
    }

    if (args.outb && tpitg.x == ntg.x - 1) {
        device float * tmp_row = (device float *) tmp +
                args.net0*i01 +
                args.net0*args.net1*i02 +
                args.net0*args.net1*args.net2*i03;

        tmp_row[ib] = s;
    }
}

typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;

template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;

template<typename T>
kernel void kernel_cumsum_add(
        constant ggml_metal_kargs_cumsum_add & args,
        device const char * tmp,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int ib = tgpig[0]/args.ne01;

    if (ib == 0) {
        return;
    }

    const int i00 = ib*ntg.x;

    const int i01 = tgpig[0]%args.ne01;
    const int i02 = tgpig[1];
    const int i03 = tgpig[2];

    device const float * tmp_row = (device const float *) (tmp +
            args.nbt1*i01 +
            args.nbt2*i02 +
            args.nbt3*i03);

    device float * dst_row = (device float *) dst +
            args.ne00*i01 +
            args.ne00*args.ne01*i02 +
            args.ne00*args.ne01*args.ne02*i03;

    if (i00 + tpitg.x < args.ne00) {
        dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
    }
}

typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;

template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
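
// The two kernels above implement a block-wise scan: kernel_cumsum_blk
// produces an inclusive prefix sum within each block of ntg.x elements and,
// when args.outb is set, writes each block's total to tmp. Since
// kernel_cumsum_add only adds the single value tmp[ib - 1] to block ib, tmp
// must itself hold scanned (prefix-summed) block totals by the time the add
// pass runs; presumably the host scans tmp between the two dispatches.
// Worked example, block size 2, input [1, 2, 3, 4]: blk pass -> [1, 3 | 3, 7]
// with block totals [3, 7]; scanning the totals gives [3, 10]; the add pass
// puts 3 onto the second block -> [1, 3, 6, 10].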
template<typename T>
kernel void kernel_soft_max(
        constant ggml_metal_kargs_soft_max & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        threadgroup float * buf [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint3 tptg[[threads_per_threadgroup]]) {
    const int32_t i03 = tgpig.z;
    const int32_t i02 = tgpig.y;
    const int32_t i01 = tgpig.x;

    const int32_t i13 = i03%args.ne13;
    const int32_t i12 = i02%args.ne12;
    const int32_t i11 = i01;

    device const float * psrc0 =                        (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
    device const T     * pmask = src1 != src0 ?         (device const T     *) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
    device const float * psrc2 = src2 != src0 ?         (device const float *) (src2) : nullptr;
    device       float * pdst  =                        (device       float *) (dst  + i01*args.nb1  + i02*args.nb2  + i03*args.nb3);

    float slope = 1.0f;

    // ALiBi
    if (args.max_bias > 0.0f) {
        const int32_t h = i02;

        const float base = h < args.n_head_log2 ? args.m0 : args.m1;
        const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;

        slope = pow(base, exp);
    }

    // parallel max
    float lmax = psrc2 ? psrc2[i02] : -INFINITY;

    for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
        lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
    }

    // find the max value in the block
    float max_val = simd_max(lmax);
    if (tptg.x > N_SIMDWIDTH) {
        if (sgitg == 0) {
            buf[tiisg] = -INFINITY;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = max_val;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        max_val = buf[tiisg];
        max_val = simd_max(max_val);
    }

    // parallel sum
    float lsum = 0.0f;
    for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
        const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
        lsum += exp_psrc0;
        pdst[i00] = exp_psrc0;
    }

    // This barrier fixes a failing test
    // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
    threadgroup_barrier(mem_flags::mem_none);

    float sum = simd_sum(lsum);

    if (tptg.x > N_SIMDWIDTH) {
        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = sum;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        sum = buf[tiisg];
        sum = simd_sum(sum);
    }

    if (psrc2) {
        sum += exp(psrc2[i02] - max_val);
    }

    const float inv_sum = 1.0f/sum;

    for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
        pdst[i00] *= inv_sum;
    }
}

template<typename T>
kernel void kernel_soft_max_4(
        constant ggml_metal_kargs_soft_max & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        threadgroup float * buf [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint3 tptg[[threads_per_threadgroup]]) {
    const int32_t i03 = tgpig.z;
    const int32_t i02 = tgpig.y;
    const int32_t i01 = tgpig.x;

    const int32_t i13 = i03%args.ne13;
    const int32_t i12 = i02%args.ne12;
    const int32_t i11 = i01;

    device const float4 * psrc4 =                       (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
    device const T      * pmask = src1 != src0 ?        (device const T      *) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
    device const float  * psrc2 = src2 != src0 ?        (device const float  *) (src2) : nullptr;
    device       float4 * pdst4 =                       (device       float4 *) (dst  + i01*args.nb1  + i02*args.nb2  + i03*args.nb3);

    float slope = 1.0f;

    if (args.max_bias > 0.0f) {
        const int32_t h = i02;

        const float base = h < args.n_head_log2 ? args.m0 : args.m1;
        const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;

        slope = pow(base, exp);
    }

    // parallel max
    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;

    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
        lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
    }

    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));

    float max_val = simd_max(lmax);
    if (tptg.x > N_SIMDWIDTH) {
        if (sgitg == 0) {
            buf[tiisg] = -INFINITY;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = max_val;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        max_val = buf[tiisg];
        max_val = simd_max(max_val);
    }

    // parallel sum
    float4 lsum4 = 0.0f;
    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
        const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
        lsum4 += exp_psrc4;
        pdst4[i00] = exp_psrc4;
    }

    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];

    // This barrier fixes a failing test
    // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
    threadgroup_barrier(mem_flags::mem_none);

    float sum = simd_sum(lsum);

    if (tptg.x > N_SIMDWIDTH) {
        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = sum;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        sum = buf[tiisg];
        sum = simd_sum(sum);
    }

    if (psrc2) {
        sum += exp(psrc2[i02] - max_val);
    }

    const float inv_sum = 1.0f/sum;

    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
        pdst4[i00] *= inv_sum;
    }
}

typedef decltype(kernel_soft_max<float>)    kernel_soft_max_t;
typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;

template [[host_name("kernel_soft_max_f16")]]   kernel kernel_soft_max_t   kernel_soft_max<half>;
template [[host_name("kernel_soft_max_f32")]]   kernel kernel_soft_max_t   kernel_soft_max<float>;
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
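
// Both softmax kernels above follow the numerically stable recipe: find the
// row max (including the scaled mask and, when max_bias > 0, the per-head
// ALiBi slope), exponentiate the shifted values, then normalize by the
// reduced sum. The optional src2 path reads like an attention-sink term: its
// value participates in the max and is added to the denominator via
// exp(psrc2[i02] - max_val), but is never written to dst (our interpretation
// of the code, not documented here).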
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
kernel void kernel_ssm_conv_f32_f32(
        constant ggml_metal_kargs_ssm_conv & args,
        device const void * src0,
        device const void * src1,
        device      float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const int64_t ir = tgpig.x;
    const int64_t i2 = tgpig.y;
    const int64_t i3 = tgpig.z;

    const int64_t nc = args.ne10;
    //const int64_t ncs = args.ne00;
    //const int64_t nr  = args.ne01;
    //const int64_t n_t = args.ne1;
    //const int64_t n_s = args.ne2;

    device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
    device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
    device       float * x = (device       float *) ((device       char *) dst  + ir*args.nb0  + i2*args.nb1  + i3*args.nb2);

    float sumf = 0.0f;

    for (int64_t i0 = 0; i0 < nc; ++i0) {
        sumf += s[i0] * c[i0];
    }

    x[0] = sumf;
}

kernel void kernel_ssm_conv_f32_f32_4(
        constant ggml_metal_kargs_ssm_conv & args,
        device const void * src0,
        device const void * src1,
        device      float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const int64_t ir = tgpig.x;
    const int64_t i2 = tgpig.y;
    const int64_t i3 = tgpig.z;

    const int64_t nc = args.ne10;
    //const int64_t ncs = args.ne00;
    //const int64_t nr  = args.ne01;
    //const int64_t n_t = args.ne1;
    //const int64_t n_s = args.ne2;

    device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
    device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
    device       float  * x = (device       float  *) ((device       char *) dst  + ir*args.nb0  + i2*args.nb1  + i3*args.nb2);

    float sumf = 0.0f;

    for (int64_t i0 = 0; i0 < nc/4; ++i0) {
        sumf += dot(s[i0], c[i0]);
    }

    x[0] = sumf;
}

// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
kernel void kernel_ssm_scan_f32(
        constant ggml_metal_kargs_ssm_scan & args,
        device const void * src0,
        device const void * src1,
        device const void * src2,
        device const void * src3,
        device const void * src4,
        device const void * src5,
        device const void * src6,
        device      float * dst,
        threadgroup float * shared [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgptg[[simdgroups_per_threadgroup]],
        uint3   tgpg[[threadgroups_per_grid]]) {
    constexpr short NW = N_SIMDWIDTH;

    shared[tpitg.x] = 0.0f;

    const int32_t i0 = tpitg.x;
    const int32_t i1 = tgpig.x;
    const int32_t ir = tgpig.y; // current head
    const int32_t i3 = tgpig.z; // current seq

    const int32_t nc  = args.d_state;
    const int32_t nr  = args.d_inner;
    const int32_t nh  = args.n_head;
    const int32_t ng  = args.n_group;
    const int32_t n_t = args.n_seq_tokens;

    const int32_t s_off = args.s_off;

    device const int32_t * ids = (device const int32_t *) src6;

    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);

    const int32_t i = i0 + i1*nc;
    const int32_t g = ir / (nh / ng); // repeat_interleave

    float s0 = s0_buff[i];
    float s  = 0.0f;

    device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}

    const float A0 = A[i0%args.ne30];

    device const float * x  = (device const float *) ((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
    device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
    device const float * B  = (device const float *) ((device const char *) src4 + g*args.nb41  + i3*args.nb43); // {d_state, ng, nt, ns}
    device const float * C  = (device const float *) ((device const char *) src5 + g*args.nb51  + i3*args.nb53); // {d_state, ng, nt, ns}

    device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}

    for (int i2 = 0; i2 < n_t; i2 += sgptg) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
            const float dt0  = dt[0];
            const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
            const float x_dt = x[0] * dtsp;
            const float dA   = exp(dtsp * A0);

            s = (s0 * dA) + (B[i0] * x_dt);

            const float sumf = simd_sum(s * C[i0]);

            if (tiisg == 0) {
                shared[t*NW + sgitg] = sumf;
            }

            // carry the recurrence state into the next token
            s0 = s;

            x  += args.ns12;
            dt += args.ns21;
            B  += args.ns42;
            C  += args.ns52;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        const float sumf = simd_sum(shared[sgitg*NW + tiisg]);

        if (tiisg == 0 && i2 + sgitg < n_t) {
            y[sgitg*nh*nr] = sumf;
        }

        y += sgptg*nh*nr;
    }

    s_buff[i] = s;
}
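
// The scan above evaluates a Mamba-2 style recurrence per state column:
//   dtsp = softplus(dt)   (for dt > 20, log(1 + exp(dt)) equals dt to float
//                          precision, and skipping exp avoids overflow)
//   s_t  = s_{t-1} * exp(dtsp * A) + B * (x * dtsp)
//   y_t  = sum_i s_t[i] * C[i]
// with the output dot product reduced first within a simdgroup and then
// across simdgroups through the shared buffer.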
kernel void kernel_rwkv_wkv6_f32(
        device const float * k,
        device const float * v,
        device const float * r,
        device const float * tf,
        device const float * td,
        device const float * state_in,
        device       float * dst,
        constant    uint & B,
        constant    uint & T,
        constant    uint & C,
        constant    uint & H,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const uint head_size = 64; // TODO: support head_size = 128
    const uint batch_id = tgpig.x / H;
    const uint head_id  = tgpig.x % H;
    const uint tid = tpitg.x;

    if (batch_id >= B || head_id >= H) {
        return;
    }

    const uint state_size   = C * head_size;
    const uint n_seq_tokens = T / B;

    threadgroup float _k[head_size];
    threadgroup float _r[head_size];
    threadgroup float _tf[head_size];
    threadgroup float _td[head_size];

    float state[head_size];

    for (uint i = 0; i < head_size; i++) {
        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
                          + i * head_size + tid];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);
    _tf[tid] = tf[head_id * head_size + tid];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
    const uint end_t   = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;

    for (uint t = start_t; t < end_t; t += C) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        _k[tid]  = k[t];
        _r[tid]  = r[t];
        _td[tid] = td[t];
        threadgroup_barrier(mem_flags::mem_threadgroup);

        const float v_val = v[t];
        float y = 0.0;

        for (uint j = 0; j < head_size; j += 4) {
            float4 k_vec  = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
            float4 r_vec  = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
            float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
            float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
            float4 s_vec  = float4(state[j], state[j+1], state[j+2], state[j+3]);

            float4 kv = k_vec * v_val;

            float4 temp = tf_vec * kv + s_vec;
            y += dot(r_vec, temp);

            s_vec = s_vec * td_vec + kv;
            state[j]   = s_vec[0];
            state[j+1] = s_vec[1];
            state[j+2] = s_vec[2];
            state[j+3] = s_vec[3];
        }

        dst[t] = y;
    }

    for (uint i = 0; i < head_size; i++) {
        dst[T * C + batch_id * state_size + head_id * head_size * head_size
          + i * head_size + tid] = state[i];
    }
}
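
// kernel_rwkv_wkv6_f32 walks one head column (tid) through the sequence and
// keeps its 64-entry state slice in registers. Per token the update is, over
// the head dimension ("o" = elementwise product):
//   y_t = r_t . (tf o (k_t * v_t) + S)
//   S  <- S o td_t + k_t * v_t
// so tf acts as a bonus applied to the current token's k*v before the state
// decays by td, matching the usual RWKV6 formulation as we read it. The
// trailing loop stores the final state after the T*C outputs, in the layout
// the next call reads back via state_in.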
kernel void kernel_rwkv_wkv7_f32(
        device const float * r,
        device const float * w,
        device const float * k,
        device const float * v,
        device const float * a,
        device const float * b,
        device const float * state_in,
        device       float * dst,
        constant    uint & B,
        constant    uint & T,
        constant    uint & C,
        constant    uint & H,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const uint head_size = 64; // TODO: support head_size = 128
    const uint batch_id = tgpig.x / H;
    const uint head_id  = tgpig.x % H;
    const uint tid = tpitg.x;

    if (batch_id >= B || head_id >= H) {
        return;
    }

    const uint state_size   = C * head_size;
    const uint n_seq_tokens = T / B;

    threadgroup float _r[head_size];
    threadgroup float _w[head_size];
    threadgroup float _k[head_size];
    threadgroup float _a[head_size];
    threadgroup float _b[head_size];

    float state[head_size];

    for (uint i = 0; i < head_size; i++) {
        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
                          + tid * head_size + i];
    }

    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
    const uint end_t   = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;

    for (uint t = start_t; t < end_t; t += C) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        _r[tid] = r[t];
        _w[tid] = w[t];
        _k[tid] = k[t];
        _a[tid] = a[t];
        _b[tid] = b[t];
        threadgroup_barrier(mem_flags::mem_threadgroup);

        const float v_val = v[t];
        float y = 0.0, sa = 0.0;

        float4 sa_vec(0.0);

        for (uint j = 0; j < head_size; j += 4) {
            float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
            sa_vec += a_vec * s_vec;
        }
        sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];

        for (uint j = 0; j < head_size; j += 4) {
            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
            float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
            float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);

            float4 kv = k_vec * v_val;

            s_vec = s_vec * w_vec + kv + sa * b_vec;
            y += dot(s_vec, r_vec);

            state[j]   = s_vec[0];
            state[j+1] = s_vec[1];
            state[j+2] = s_vec[2];
            state[j+3] = s_vec[3];
        }

        dst[t] = y;
    }

    for (uint i = 0; i < head_size; i++) {
        dst[T * C + batch_id * state_size + head_id * head_size * head_size
          + tid * head_size + i] = state[i];
    }
}

kernel void kernel_argmax_f32(
        constant ggml_metal_kargs_argmax & args,
        device const char * src0,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint ntg[[threads_per_threadgroup]]) {
    device const float * x_row = (device const float *) ((device const char *) src0 + tgpig * args.nb01);

    float   lmax = -INFINITY;
    int32_t larg = -1;

    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
        if (x_row[i00] > lmax) {
            lmax = x_row[i00];
            larg = i00;
        }
    }

    // find the argmax value in the block
    float   max_val = simd_max(lmax);
    int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));

    device int32_t * dst_i32 = (device int32_t *) dst;

    threadgroup float   * shared_maxval = (threadgroup float   *) shmem;
    threadgroup int32_t * shared_argmax = (threadgroup int32_t *) shmem + N_SIMDWIDTH;

    if (ntg > N_SIMDWIDTH) {
        if (sgitg == 0) {
            shared_maxval[tiisg] = -INFINITY;
            shared_argmax[tiisg] = -1;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            shared_maxval[sgitg] = max_val;
            shared_argmax[sgitg] = arg_val;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        max_val = shared_maxval[tiisg];
        arg_val = shared_argmax[tiisg];

        float   max_val_reduced = simd_max(max_val);
        int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));

        dst_i32[tgpig] = arg_val_reduced;

        return;
    }

    dst_i32[tgpig] = arg_val;
}

// F == 1 : norm (no fuse)
// F == 2 : norm + mul
// F == 3 : norm + mul + add
template <typename T, short F>
kernel void kernel_norm_fuse_impl(
        constant ggml_metal_kargs_norm & args,
        device const char * src0,
        device const char * src1_0,
        device const char * src1_1,
- device char * dst,
- threadgroup float * shmem_f32 [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- if (sgitg == 0) {
- shmem_f32[tiisg] = 0.0f;
- }
- const int i01 = tgpig.x;
- const int i02 = tgpig.y;
- const int i03 = tgpig.z;
- device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
- device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
- device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
- T sumft(0.0f);
- float sumf = 0.0f;
- for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
- sumft += x[i00];
- }
- sumf = dot(sumft, T(1.0f));
- sumf = simd_sum(sumf);
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (tiisg == 0) {
- shmem_f32[sgitg] = sumf;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- sumf = shmem_f32[tiisg];
- sumf = simd_sum(sumf);
- const float mean = sumf/args.ne00;
- device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
- sumf = 0.0f;
- for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
- y[i00] = x[i00] - mean;
- sumf += dot(y[i00], y[i00]);
- }
- sumf = simd_sum(sumf);
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (tiisg == 0) {
- shmem_f32[sgitg] = sumf;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- sumf = shmem_f32[tiisg];
- sumf = simd_sum(sumf);
- const float variance = sumf/args.ne00;
- const float scale = 1.0f/sqrt(variance + args.eps);
- for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
- if (F == 1) {
- y[i00] = (y[i00]*scale);
- }
- if (F == 2) {
- y[i00] = (y[i00]*scale)*f0[i00];
- }
- if (F == 3) {
- y[i00] = (y[i00]*scale)*f0[i00] + f1[i00];
- }
- }
- }
- typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;
- template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;
- template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>;
- template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;
- template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;
- template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;
- template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;
- // F == 1 : rms_norm (no fuse)
- // F == 2 : rms_norm + mul
- // F == 3 : rms_norm + mul + add
- template <typename T, short F>
- kernel void kernel_rms_norm_fuse_impl(
- constant ggml_metal_kargs_norm & args,
- device const char * src0,
- device const char * src1_0,
- device const char * src1_1,
- device char * dst,
- threadgroup float * shmem_f32 [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- if (sgitg == 0) {
- shmem_f32[tiisg] = 0.0f;
- }
- const int i01 = tgpig.x;
- const int i02 = tgpig.y;
- const int i03 = tgpig.z;
- device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
- device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
- device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
- float sumf = 0.0f;
- // parallel sum
- for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
- sumf += dot(x[i00], x[i00]);
- }
- sumf = simd_sum(sumf);
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (tiisg == 0) {
- shmem_f32[sgitg] = sumf;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- sumf = shmem_f32[tiisg];
- sumf = simd_sum(sumf);
- const float mean = sumf/args.ne00;
- const float scale = 1.0f/sqrt(mean + args.eps);
- device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
- for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
- if (F == 1) {
- y[i00] = (x[i00]*scale);
- }
- if (F == 2) {
- y[i00] = (x[i00]*scale)*f0[i00];
- }
- if (F == 3) {
- y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
- }
- }
- }
- typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;
- template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;
- template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;
- template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;
- template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
- template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
- template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
- kernel void kernel_l2_norm_f32(
- constant ggml_metal_kargs_l2_norm & args,
- device const char * src0,
- device char * dst,
- threadgroup float * shmem_f32 [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- ushort tpitg[[thread_position_in_threadgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort ntg[[threads_per_threadgroup]]) {
- if (sgitg == 0) {
- shmem_f32[tiisg] = 0.0f;
- }
- device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
- float sumf = 0.0f;
- // parallel sum
- for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
- sumf += dot(x[i00], x[i00]);
- }
- sumf = simd_sum(sumf);
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (tiisg == 0) {
- shmem_f32[sgitg] = sumf;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- sumf = shmem_f32[tiisg];
- sumf = simd_sum(sumf);
- const float scale = 1.0f/sqrt(max(sumf, args.eps));
- device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
- for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
- y[i00] = x[i00] * scale;
- }
- }
- kernel void kernel_group_norm_f32(
- constant ggml_metal_kargs_group_norm & args,
- device const float * src0,
- device float * dst,
- threadgroup float * buf [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- uint tpitg[[thread_position_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint ntg[[threads_per_threadgroup]]) {
- const int64_t ne = args.ne00*args.ne01*args.ne02;
- const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp);
- int start = tgpig * gs;
- int end = start + gs;
- start += tpitg;
- if (end >= ne) {
- end = ne;
- }
- float tmp = 0.0f; // partial sum for thread in warp
- for (int j = start; j < end; j += ntg) {
- tmp += src0[j];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- tmp = simd_sum(tmp);
- if (ntg > N_SIMDWIDTH) {
- if (sgitg == 0) {
- buf[tiisg] = 0.0f;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (tiisg == 0) {
- buf[sgitg] = tmp;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- tmp = buf[tiisg];
- tmp = simd_sum(tmp);
- }
- const float mean = tmp / gs;
- tmp = 0.0f;
- for (int j = start; j < end; j += ntg) {
- float xi = src0[j] - mean;
- dst[j] = xi;
- tmp += xi * xi;
- }
- tmp = simd_sum(tmp);
- if (ntg > N_SIMDWIDTH) {
- if (sgitg == 0) {
- buf[tiisg] = 0.0f;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (tiisg == 0) {
- buf[sgitg] = tmp;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- tmp = buf[tiisg];
- tmp = simd_sum(tmp);
- }
- const float variance = tmp / gs;
- const float scale = 1.0f/sqrt(variance + args.eps);
- for (int j = start; j < end; j += ntg) {
- dst[j] *= scale;
- }
- }
- // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
- // il indicates where the q4 quants begin (0 or QK4_0/4)
- // we assume that the yl's have been multiplied with the appropriate scale factor
- // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
- inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
- float d = qb_curr->d;
- float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
- device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
- for (int i = 0; i < 8; i += 2) {
- acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
- acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
- acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
- acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
- }
- return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
- }
- // function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
- // il indicates where the q4 quants begin (0 or QK4_0/4)
- // we assume that the yl's have been multiplied with the appropriate scale factor
- // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
- inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
- float d = qb_curr->d;
- float m = qb_curr->m;
- float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
- device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);
- for (int i = 0; i < 8; i+=2) {
- acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
- acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
- acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
- acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
- }
- return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
- }
- // function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
- // il indicates where the q5 quants begin (0 or QK5_0/4)
- // we assume that the yl's have been multiplied with the appropriate scale factor
- // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
- inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
- float d = qb_curr->d;
- float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
- const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
- for (int i = 0; i < 8; i+=2) {
- acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
- acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
- acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
- acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
- }
- return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
- }
- // function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
- // il indicates where the q5 quants begin (0 or QK5_1/4)
- // we assume that the yl's have been multiplied with the appropriate scale factor
- // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
- inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
- float d = qb_curr->d;
- float m = qb_curr->m;
- float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
- const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
- for (int i = 0; i < 8; i+=2) {
- acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
- acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
- acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
- acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
- }
- return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
- }
- template<short NR0>
- static inline void helper_mv_reduce_and_write(
- device float * dst_f32,
- float sumf[NR0],
- const int r0,
- const int ne01,
- ushort tiisg,
- ushort sgitg,
- threadgroup char * shmem) {
- constexpr short NW = N_SIMDWIDTH;
- threadgroup float * shmem_f32[NR0];
- for (short row = 0; row < NR0; ++row) {
- shmem_f32[row] = (threadgroup float *) shmem + NW*row;
- if (sgitg == 0) {
- shmem_f32[row][tiisg] = 0.0f;
- }
- sumf[row] = simd_sum(sumf[row]);
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- for (short row = 0; row < NR0; ++row) {
- if (tiisg == 0) {
- shmem_f32[row][sgitg] = sumf[row];
- }
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- for (short row = 0; row < NR0 && r0 + row < ne01; ++row) {
- float tot = simd_sum(shmem_f32[row][tiisg]);
- if (tiisg == 0 && sgitg == 0) {
- dst_f32[r0 + row] = tot;
- }
- }
- }
- constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
- constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];
- template<typename block_q_type, short NR0, typename args_t>
- void mul_vec_q_n_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- constexpr short NW = N_SIMDWIDTH;
- constexpr short NQ = 16;
- const int nb = args.ne00/QK4_0;
- const int r0 = (tgpig.x*NSG + sgitg)*NR0;
- //const int r0 = tgpig.x*NR0;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- //device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
- // pointers to src0 rows
- device const block_q_type * ax[NR0];
- FOR_UNROLL (int row = 0; row < NR0; ++row) {
- const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
- }
- float sumf[NR0] = {0.f};
- const short ix = (tiisg/(NW/NQ));
- const short il = (tiisg%(NW/NQ))*8;
- //const int ib0 = sgitg*NQ + ix;
- const int ib0 = ix;
- float yl[16]; // src1 vector cache
- //device const float * yb = y + ix*QK4_0 + il;
- device const float * yb = y + ib0*QK4_0 + il;
- // each thread in a SIMD group deals with half a block.
- //for (int ib = ib0; ib < nb; ib += NSG*NQ) {
- for (int ib = ib0; ib < nb; ib += NQ) {
- float sumy[2] = { 0.f, 0.f };
- FOR_UNROLL (short i = 0; i < 8; i += 2) {
- sumy[0] += yb[i + 0] + yb[i + 1];
- yl[i + 0] = yb[i + 0];
- yl[i + 1] = yb[i + 1]/256.f;
- sumy[1] += yb[i + 16] + yb[i + 17];
- yl[i + 8] = yb[i + 16]/16.f;
- yl[i + 9] = yb[i + 17]/4096.f;
- }
- FOR_UNROLL (short row = 0; row < NR0; row++) {
- sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
- }
- yb += QK4_0 * 16;
- //yb += NSG*NQ*QK4_0;
- }
- device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
- //helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
- for (int row = 0; row < NR0; ++row) {
- const float tot = simd_sum(sumf[row]);
- if (tiisg == 0 && r0 + row < args.ne01) {
- dst_f32[r0 + row] = tot;
- }
- }
- }
- kernel void kernel_mul_mv_q4_0_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- kernel void kernel_mul_mv_q4_1_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- kernel void kernel_mul_mv_q5_0_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- kernel void kernel_mul_mv_q5_1_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- template<short NR0, typename args_t>
- void kernel_mul_mv_q8_0_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- constexpr short NW = N_SIMDWIDTH;
- constexpr short NQ = 8;
- const int nb = args.ne00/QK8_0;
- const int r0 = tgpig.x*NR0;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
- // pointers to src0 rows
- device const block_q8_0 * ax[NR0];
- FOR_UNROLL (short row = 0; row < NR0; ++row) {
- const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
- }
- float sumf[NR0] = { 0.f };
- const short ix = tiisg/(NW/NQ);
- const short il = tiisg%(NW/NQ);
- const int ib0 = sgitg*NQ + ix;
- float yl[NQ];
- device const float * yb = y + ib0*QK8_0 + il*NQ;
- // each thread in a SIMD group deals with NQ quants at a time
- for (int ib = ib0; ib < nb; ib += NSG*NQ) {
- for (short i = 0; i < NQ; ++i) {
- yl[i] = yb[i];
- }
- for (short row = 0; row < NR0; row++) {
- device const int8_t * qs = ax[row][ib].qs + il*NQ;
- float sumq = 0.f;
- FOR_UNROLL (short i = 0; i < NQ; ++i) {
- sumq += qs[i] * yl[i];
- }
- sumf[row] += sumq*ax[row][ib].d;
- }
- yb += NSG*NQ*QK8_0;
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
- }
- [[host_name("kernel_mul_mv_q8_0_f32")]]
- kernel void kernel_mul_mv_q8_0_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- // mat-vec kernel processing in chunks of float4
- // chpb - chunks per quantization block
- template<short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
- void kernel_mul_mv_ext_q4_f32_impl(
- constant ggml_metal_kargs_mul_mv_ext & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- const short NSG = FC_mul_mv_nsg;
- const short nxpsg = FC_mul_mv_nxpsg;
- const short chpt = 4; // chunks per thread
- //const short nxpsg = (32);
- const short nypsg = (32/nxpsg);
- const short tx = tiisg%nxpsg;
- const short ty = tiisg/nxpsg;
- const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
- const int i11 = tgpig.y*r1ptg;
- const int i1m = tgpig.z;
- const int i12 = i1m%args.ne12;
- const int i13 = i1m/args.ne12;
- const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
- device const float4 * y4[r1ptg];
- for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
- y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
- }
- float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
- short cch = tx%chpb; // current chunk index
- for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
- float4 lx[chpt];
- #pragma unroll(chpt)
- for (short ch = 0; ch < chpt; ++ch) {
- deq_t4(xq, cch, lx[ch]);
- cch += nxpsg;
- if (cch >= chpb) {
- xq += cch/chpb;
- cch %= chpb;
- }
- }
- #pragma unroll(chpt)
- for (short ch = 0; ch < chpt; ++ch) {
- #pragma unroll(r1ptg)
- for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
- sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
- }
- }
- #pragma unroll(r1ptg)
- for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
- y4[ir1] += chpt*nxpsg;
- }
- }
- // reduce only the threads in each row
- for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
- if (nxpsg >= 32) {
- sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
- }
- if (nxpsg >= 16) {
- sumf[ir1] += simd_shuffle_down(sumf[ir1], 8);
- }
- if (nxpsg >= 8) {
- sumf[ir1] += simd_shuffle_down(sumf[ir1], 4);
- }
- if (nxpsg >= 4) {
- sumf[ir1] += simd_shuffle_down(sumf[ir1], 2);
- }
- if (nxpsg >= 2) {
- sumf[ir1] += simd_shuffle_down(sumf[ir1], 1);
- }
- //sumf[ir1] = simd_sum(sumf[ir1]);
- }
- if (tx == 0) {
- for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
- device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
- if (i01 < args.ne01) {
- dst_f32[i01] = sumf[ir1];
- }
- }
- }
- }
- // mat-vec kernel processing in chunks of float4x4
- template<short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
- void kernel_mul_mv_ext_q4x4_f32_impl(
- constant ggml_metal_kargs_mul_mv_ext & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- const short NSG = FC_mul_mv_nsg;
- const short nxpsg = FC_mul_mv_nxpsg;
- const short chpt = 1;
- //const short nxpsg = (32);
- const short nypsg = (32/nxpsg);
- const short tx = tiisg%nxpsg;
- const short ty = tiisg/nxpsg;
- const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
- const int i11 = tgpig.y*r1ptg;
- const int i1m = tgpig.z;
- const int i12 = i1m%args.ne12;
- const int i13 = i1m/args.ne12;
- const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
- device const float4x4 * y4x4[r1ptg];
- for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
- y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1;
- }
- float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
- short cch = tx%chpb;
- for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) {
- float4x4 lx[chpt];
- #pragma unroll(chpt)
- for (short ch = 0; ch < chpt; ++ch) {
- deq_t4x4(xq, cch, lx[ch]);
- cch += nxpsg;
- if (cch >= chpb) {
- xq += cch/chpb;
- cch %= chpb;
- }
- }
- #pragma unroll(chpt)
- for (short ch = 0; ch < chpt; ++ch) {
- #pragma unroll(r1ptg)
- for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
- sumf[ir1] +=
- dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) +
- dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) +
- dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) +
- dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]);
- }
- }
- #pragma unroll(r1ptg)
- for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
- y4x4[ir1] += chpt*nxpsg;
- }
- }
- for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
- if (nxpsg >= 32) {
- sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
- }
- if (nxpsg >= 16) {
- sumf[ir1] += simd_shuffle_down(sumf[ir1], 8);
- }
- if (nxpsg >= 8) {
- sumf[ir1] += simd_shuffle_down(sumf[ir1], 4);
- }
- if (nxpsg >= 4) {
- sumf[ir1] += simd_shuffle_down(sumf[ir1], 2);
- }
- if (nxpsg >= 2) {
- sumf[ir1] += simd_shuffle_down(sumf[ir1], 1);
- }
- //sumf[ir1] = simd_sum(sumf[ir1]);
- }
- if (tx == 0) {
- for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
- device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
- if (i01 < args.ne01) {
- dst_f32[i01] = sumf[ir1];
- }
- }
- }
- }
- // dispatchers needed for compile-time nxpsg
- // epb - elements per quantization block
- template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
- kernel void kernel_mul_mv_ext_q4_f32_disp(
- constant ggml_metal_kargs_mul_mv_ext & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_ext_q4_f32_impl<r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
- }
- template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
- kernel void kernel_mul_mv_ext_q4x4_f32_disp(
- constant ggml_metal_kargs_mul_mv_ext & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_ext_q4x4_f32_impl<r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
- }
- typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
- typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t;
- template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4, 4, dequantize_f32_t4>;
- template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4, 4, dequantize_f32_t4>;
- template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4, 4, dequantize_f32_t4>;
- template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4, 4, dequantize_f32_t4>;
- template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>;
- template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
- template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
- template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
- template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
- template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
- template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
- template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0, 32, dequantize_q4_0_t4>;
- template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1, 32, dequantize_q4_1_t4>;
- template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1, 32, dequantize_q4_1_t4>;
- template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1, 32, dequantize_q4_1_t4>;
- template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1, 32, dequantize_q4_1_t4>;
- template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0, 32, dequantize_q5_0_t4>;
- template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0, 32, dequantize_q5_0_t4>;
- template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0, 32, dequantize_q5_0_t4>;
- template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0, 32, dequantize_q5_0_t4>;
- template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1, 32, dequantize_q5_1_t4>;
- template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1, 32, dequantize_q5_1_t4>;
- template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1, 32, dequantize_q5_1_t4>;
- template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1, 32, dequantize_q5_1_t4>;
- template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>;
- template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0, 32, dequantize_q8_0_t4>;
- template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>;
- template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>;
- template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4, 32, dequantize_mxfp4_t4>;
- template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4, 32, dequantize_mxfp4_t4>;
- template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4, 32, dequantize_mxfp4_t4>;
- template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4, 32, dequantize_mxfp4_t4>;
- template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
- template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
- template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
- template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
- template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>;
- template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>;
- template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>;
- template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>;
- template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>;
- template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>;
- template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>;
- template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>;
- template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>;
- template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>;
- template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
- template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
- template<typename T0, typename T1, short NR0, typename args_t>
- void kernel_mul_mv_t_t_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- constexpr short NW = N_SIMDWIDTH;
- constexpr short NB = 32;
- constexpr short NF = 8;
- const int nb = args.ne00/NB;
- const int r0 = tgpig.x*NR0;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- //device const T0 * x = (device const T0 *) (src0 + offset0);
- device const T1 * y = (device const T1 *) (src1 + offset1);
- // pointers to src0 rows
- device const T0 * ax [NR0];
- FOR_UNROLL (short row = 0; row < NR0; ++row) {
- const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- ax[row] = (device const T0 *) ((device char *) src0 + offset0);
- }
- float sumf[NR0] = { 0.f };
- const short ix = tiisg/(NW/NF);
- const short il = tiisg%(NW/NF);
- const int ib0 = sgitg*NF + ix;
- T1 yl[NF];
- device const T1 * yb = y + (ib0*NB + il*NF);
- for (int ib = ib0; ib < nb; ib += NSG*NF) {
- for (short i = 0; i < NF; ++i) {
- yl[i] = yb[i];
- }
- for (short row = 0; row < NR0; row++) {
- device const T0 * xb = ax[row] + (ib*NB + il*NF);
- float sumq = 0.f;
- FOR_UNROLL (short i = 0; i < NF; ++i) {
- sumq += xb[i] * yl[i];
- }
- sumf[row] += sumq;
- }
- yb += NSG*NF*NW;
- }
- for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
- for (short row = 0; row < NR0; row++) {
- sumf[row] += ax[row][i] * y[i];
- }
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
- }
- template<typename T0, typename T1, typename args_t>
- void kernel_mul_mv_t_t_disp(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- switch (args.nr0) {
- //case 1: kernel_mul_mv_t_t_impl<T0, T1, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
- case 2: kernel_mul_mv_t_t_impl<T0, T1, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
- //case 3: kernel_mul_mv_t_t_impl<T0, T1, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
- //case 4: kernel_mul_mv_t_t_impl<T0, T1, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
- }
- }
- template<typename T0, typename T1>
- kernel void kernel_mul_mv_t_t(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_t_t_disp<T0, T1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- typedef decltype(kernel_mul_mv_t_t<half, half>) mul_mv_t_t;
- template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float>;
- template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float>;
- template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float>;
- template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat>;
- #endif
- template<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t>
- void kernel_mul_mv_t_t_4_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- constexpr short NW = N_SIMDWIDTH;
- constexpr short NB = 32;
- constexpr short NF = 16;
- constexpr short NF4 = NF/4;
- const int nb = args.ne00/NB;
- const int r0 = tgpig.x*NR0;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const T1 * y = (device const T1 *) (src1 + offset1);
- device const T14 * y4 = (device const T14 *) (src1 + offset1);
- // pointers to src0 rows
- device const T0 * ax [NR0];
- device const T04 * ax4[NR0];
- FOR_UNROLL (short row = 0; row < NR0; ++row) {
- const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- ax [row] = (device const T0 *) ((device char *) src0 + offset0);
- ax4[row] = (device const T04 *) ((device char *) src0 + offset0);
- }
- float sumf[NR0] = { 0.f };
- const short ix = tiisg/(NW/NF);
- const short il = tiisg%(NW/NF);
- const int ib0 = sgitg*NF + ix;
- T14 yl4[NF4];
- device const T14 * yb4 = y4 + (ib0*NB + il*NF)/4;
- for (int ib = ib0; ib < nb; ib += NSG*NF) {
- for (short i = 0; i < NF4; ++i) {
- yl4[i] = yb4[i];
- }
- for (short row = 0; row < NR0; row++) {
- device const T04 * xb4 = ax4[row] + (ib*NB + il*NF)/4;
- float sumq = 0.f;
- FOR_UNROLL (short i = 0; i < NF4; ++i) {
- sumq += dot(float4(xb4[i]), float4(yl4[i]));
- }
- sumf[row] += sumq;
- }
- yb4 += NSG*NF*NW/4;
- }
- for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
- for (short row = 0; row < NR0; row++) {
- sumf[row] += ax[row][i] * y[i];
- }
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
- }
- template<typename T0, typename T04, typename T1, typename T14, typename args_t>
- void kernel_mul_mv_t_t_4_disp(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- switch (args.nr0) {
- //case 1: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
- case 2: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
- //case 3: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
- //case 4: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
- };
- }
- template<typename T0, typename T04, typename T1, typename T14>
- kernel void kernel_mul_mv_t_t_4(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_t_t_4_disp<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4>) mul_mv_t_t_4;
- template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4>;
- template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4>;
- template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4>;
- template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4>;
- #endif
- template<typename T0, typename T1, typename args_t>
- void kernel_mul_mv_t_t_short_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig,
- ushort tiisg) {
- const int r0 = tgpig.x*32 + tiisg;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- if (r0 >= args.ne01) {
- return;
- }
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- device const T0 * x = (device const T0 *) (src0 + offset0);
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const T1 * y = (device const T1 *) (src1 + offset1);
- float res = 0.0f;
- for (int i = 0; i < args.ne00; ++i) {
- res += (float) x[i] * (float) y[i];
- }
- dst_f32[(uint64_t)r1*args.ne0 + r0] = res;
- }
- template<typename T0, typename T1>
- kernel void kernel_mul_mv_t_t_short(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]]) {
- kernel_mul_mv_t_t_short_impl<T0, T1, constant ggml_metal_kargs_mul_mv &>(
- args,
- src0,
- src1,
- dst,
- tgpig,
- tiisg);
- }
- typedef decltype(kernel_mul_mv_t_t_short<half, half>) mul_mv_t_t_short_t;
- template [[host_name("kernel_mul_mv_f32_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<float, float>;
- template [[host_name("kernel_mul_mv_f16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half, float>;
- template [[host_name("kernel_mul_mv_f16_f16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half, half>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, float>;
- template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
- #endif
- constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
- static float rope_yarn_ramp(const float low, const float high, const int i0) {
- const float y = (i0 / 2 - low) / max(0.001f, high - low);
- return 1.0f - min(1.0f, max(0.0f, y));
- }
- // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
- // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
- static void rope_yarn(
- float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
- thread float * cos_theta, thread float * sin_theta) {
- // Get n-d rotational scaling corrected for extrapolation
- float theta_interp = freq_scale * theta_extrap;
- float theta = theta_interp;
- if (ext_factor != 0.0f) {
- float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
- theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
- // Get n-d magnitude scaling corrected for interpolation
- mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
- }
- *cos_theta = cos(theta) * mscale;
- *sin_theta = sin(theta) * mscale;
- }
- // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
- // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
- static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
- return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
- }
- static void rope_yarn_corr_dims(
- int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
- ) {
- // start and end correction dims
- dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
- dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
- }
- template<typename T>
- kernel void kernel_rope_norm(
- constant ggml_metal_kargs_rope & args,
- device const char * src0,
- device const char * src1,
- device const char * src2,
- device char * dst,
- ushort tiitg[[thread_index_in_threadgroup]],
- ushort3 tptg [[threads_per_threadgroup]],
- uint3 tgpig[[threadgroup_position_in_grid]]) {
- const int i3 = tgpig[2];
- const int i2 = tgpig[1];
- const int i1 = tgpig[0];
- float corr_dims[2];
- rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
- device const int32_t * pos = (device const int32_t *) src1;
- const float theta_base = (float) pos[i2];
- const float inv_ndims = -1.f/args.n_dims;
- float cos_theta;
- float sin_theta;
- for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
- if (i0 < args.n_dims) {
- const int ic = i0/2;
- const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
- const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
- rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
- device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
- device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- const float x0 = src[0];
- const float x1 = src[1];
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[1] = x0*sin_theta + x1*cos_theta;
- } else {
- device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
- device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- dst_data[0] = src[0];
- dst_data[1] = src[1];
- }
- }
- }
- template<typename T>
- kernel void kernel_rope_neox(
- constant ggml_metal_kargs_rope & args,
- device const char * src0,
- device const char * src1,
- device const char * src2,
- device char * dst,
- ushort tiitg[[thread_index_in_threadgroup]],
- ushort3 tptg [[threads_per_threadgroup]],
- uint3 tgpig[[threadgroup_position_in_grid]]) {
- const int i3 = tgpig[2];
- const int i2 = tgpig[1];
- const int i1 = tgpig[0];
- float corr_dims[2];
- rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
- device const int32_t * pos = (device const int32_t *) src1;
- const float theta_base = (float) pos[i2];
- const float inv_ndims = -1.f/args.n_dims;
- float cos_theta;
- float sin_theta;
- for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
- if (i0 < args.n_dims) {
- const int ic = i0/2;
- const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
- const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
- rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
- device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
- device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
- const float x0 = src[0];
- const float x1 = src[args.n_dims/2];
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
- } else {
- device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
- device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- dst_data[0] = src[0];
- dst_data[1] = src[1];
- }
- }
- }
- template<typename T>
- kernel void kernel_rope_multi(
- constant ggml_metal_kargs_rope & args,
- device const char * src0,
- device const char * src1,
- device const char * src2,
- device char * dst,
- ushort tiitg[[thread_index_in_threadgroup]],
- ushort3 tptg [[threads_per_threadgroup]],
- uint3 tgpig[[threadgroup_position_in_grid]]) {
- const int i3 = tgpig[2];
- const int i2 = tgpig[1];
- const int i1 = tgpig[0];
- float corr_dims[2];
- rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
- device const int32_t * pos = (device const int32_t *) src1;
- const float inv_ndims = -1.f/args.n_dims;
- float cos_theta;
- float sin_theta;
- for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
- if (i0 < args.n_dims) {
- const int ic = i0/2;
- // mrope theta calculations
- // note: the rest is the same as kernel_rope_neox
- const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
- const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
- const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
- const int sector = ic % sect_dims;
- float theta_base;
- if (FC_rope_is_imrope) {
- if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
- theta_base = (float) pos[i2 + args.ne02 * 1];
- } else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
- theta_base = (float) pos[i2 + args.ne02 * 2];
- } else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
- theta_base = (float) pos[i2 + args.ne02 * 0];
- } else { // e
- theta_base = (float) pos[i2 + args.ne02 * 3];
- }
- } else {
- if (sector < args.sect_0) {
- theta_base = (float) pos[i2];
- } else if (sector < sec_w01) {
- theta_base = (float) pos[i2 + args.ne02 * 1];
- } else if (sector < sec_w012) {
- theta_base = (float) pos[i2 + args.ne02 * 2];
- } else {
- theta_base = (float) pos[i2 + args.ne02 * 3];
- }
- }
- // end of mrope
- const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
- const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
- rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
- device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
- device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
- const float x0 = src[0];
- const float x1 = src[args.n_dims/2];
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
- } else {
- device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
- device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- dst_data[0] = src[0];
- dst_data[1] = src[1];
- }
- }
- }
- template<typename T>
- kernel void kernel_rope_vision(
- constant ggml_metal_kargs_rope & args,
- device const char * src0,
- device const char * src1,
- device const char * src2,
- device char * dst,
- ushort tiitg[[thread_index_in_threadgroup]],
- ushort3 tptg [[threads_per_threadgroup]],
- uint3 tgpig[[threadgroup_position_in_grid]]) {
- const int i3 = tgpig[2];
- const int i2 = tgpig[1];
- const int i1 = tgpig[0];
- float corr_dims[2];
- rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
- device const int32_t * pos = (device const int32_t *) src1;
- const float inv_ndims = -1.f/args.n_dims;
- float cos_theta;
- float sin_theta;
- for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
- if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
- const int ic = i0/2;
- // mrope theta calculations (only support 2 dimensions)
- const int sect_dims = args.sect_0 + args.sect_1;
- const int sector = ic % sect_dims;
- float p;
- float theta_base;
- if (sector < args.sect_1) {
- p = (float) sector;
- theta_base = (float) pos[i2];
- } else {
- p = (float) sector - args.sect_0;
- theta_base = (float) pos[i2 + args.ne02];
- }
- const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
- // end of mrope
- const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
- rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
- device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
- device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
- const float x0 = src[0];
- const float x1 = src[args.n_dims]; // different from kernel_rope_multi
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
- } else {
- device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
- device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- dst_data[0] = src[0];
- dst_data[1] = src[1];
- }
- }
- }
- typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
- typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
- typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
- typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
- template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
- template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
- template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
- template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
- template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
- template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
- template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
- template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
- typedef void (im2col_t)(
- constant ggml_metal_kargs_im2col & args,
- device const float * x,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tgpg[[threadgroups_per_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]);
- template <typename T>
- kernel void kernel_im2col(
- constant ggml_metal_kargs_im2col & args,
- device const float * x,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tgpg[[threadgroups_per_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- // const int64_t IC = tgpg[0];
- const int64_t OH = tgpg[1];
- const int64_t OW = tgpg[2];
- const int64_t KH = ntg[1];
- const int64_t KW = ntg[2];
- int64_t in = tpitg[0];
- const int64_t ikh = tpitg[1];
- const int64_t ikw = tpitg[2];
- const int64_t iic = tgpig[0];
- const int64_t ioh = tgpig[1];
- const int64_t iow = tgpig[2];
- const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
- const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
- int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
- device T * pdst = (device T *) (dst);
- if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
- while (in < args.N) {
- pdst[offset_dst] = 0.0f;
- offset_dst += ntg[0]*args.CHW*OH*OW;
- in += ntg[0];
- }
- } else {
- int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
- while (in < args.N) {
- pdst[offset_dst] = x[offset_src];
- offset_dst += ntg[0]*args.CHW*OH*OW;
- offset_src += ntg[0]*args.ofs0;
- in += ntg[0];
- }
- }
- }
- template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
- template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
- // TODO: obolete -- remove
- //typedef void (im2col_ext_t)(
- // constant ggml_metal_kargs_im2col & args,
- // device const float * x,
- // device char * dst,
- // uint3 tgpig[[threadgroup_position_in_grid]],
- // uint3 tgpg[[threadgroups_per_grid]],
- // uint3 tpitg[[thread_position_in_threadgroup]],
- // uint3 ntg[[threads_per_threadgroup]]);
- //
- //template <typename T>
- //kernel void kernel_im2col_ext(
- // constant ggml_metal_kargs_im2col & args,
- // device const float * x,
- // device char * dst,
- // uint3 tgpig[[threadgroup_position_in_grid]],
- // uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
- // uint3 tpitg[[thread_position_in_threadgroup]],
- // uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
- // const int64_t KHW = (int64_t)args.KHW;
- //
- // const int64_t d = tgpig[0] / args.CHW;
- // const int64_t chw = tgpig[0] % args.CHW;
- // const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
- // const int64_t HW = tgpig[0] % KHW;
- //
- // const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
- // if (tpitg_0 >= args.N) {
- // return;
- // }
- //
- // const int64_t tpitg_1 = HW / args.KW;
- // const int64_t tpitg_2 = HW % args.KW;
- //
- // const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
- // const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
- //
- // const int64_t offset_dst =
- // (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
- // (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
- //
- // device T * pdst = (device T *) (dst);
- //
- // if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
- // pdst[offset_dst] = 0.0f;
- // } else {
- // const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
- // pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
- // }
- //}
- //
- //template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
- //template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
- template <typename TK>
- kernel void kernel_conv_2d(
- constant ggml_metal_kargs_conv_2d & args,
- device const char * weights,
- device const char * src,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tgpg[[threadgroups_per_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
- const uint tg_index = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
- const uint local_thread = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
- const uint thread_index = tg_index * threads_per_tg + local_thread;
- const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
- const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;
- for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
- uint64_t tmp = index;
- const int32_t ow = tmp % args.OW; tmp /= args.OW;
- const int32_t oh = tmp % args.OH; tmp /= args.OH;
- const int32_t oc = tmp % args.OC; tmp /= args.OC;
- const int32_t n = tmp;
- float acc = 0.0f;
- const int32_t base_x = ow*args.s0 - args.p0;
- const int32_t base_y = oh*args.s1 - args.p1;
- int32_t ky_start = 0;
- if (base_y < 0) {
- ky_start = (-base_y + args.d1 - 1)/args.d1;
- }
- int32_t ky_end = args.KH;
- const int32_t y_max = args.IH - 1 - base_y;
- if (y_max < 0) {
- ky_end = ky_start;
- } else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
- ky_end = min(ky_end, y_max/args.d1 + 1);
- }
- int32_t kx_start = 0;
- if (base_x < 0) {
- kx_start = (-base_x + args.d0 - 1)/args.d0;
- }
- int32_t kx_end = args.KW;
- const int32_t x_max = args.IW - 1 - base_x;
- if (x_max < 0) {
- kx_end = kx_start;
- } else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
- kx_end = min(kx_end, x_max/args.d0 + 1);
- }
- if (ky_start < ky_end && kx_start < kx_end) {
- const uint64_t src_base_n = (uint64_t) n * args.nb13;
- const uint64_t w_base_oc = (uint64_t) oc * args.nb03;
- for (int32_t ic = 0; ic < args.IC; ++ic) {
- const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
- const uint64_t w_base_ocic = w_base_oc + (uint64_t) ic * args.nb02;
- for (int32_t ky = ky_start; ky < ky_end; ++ky) {
- const int32_t iy = base_y + ky*args.d1;
- const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
- const uint64_t w_base_row = w_base_ocic + (uint64_t) ky * args.nb01;
- for (int32_t kx = kx_start; kx < kx_end; ++kx) {
- const int32_t ix = base_x + kx*args.d0;
- const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
- const uint64_t w_offs = w_base_row + (uint64_t) kx * args.nb00;
- const float x = *(device const float *)(src + src_offs);
- const float w = (float) (*(device const TK *)(weights + w_offs));
- acc += x * w;
- }
- }
- }
- }
- const uint64_t dst_offs =
- (uint64_t) n * args.nb3 +
- (uint64_t) oc * args.nb2 +
- (uint64_t) oh * args.nb1 +
- (uint64_t) ow * args.nb0;
- *(device float *)(dst + dst_offs) = acc;
- }
- }
- template [[host_name("kernel_conv_2d_f32_f32")]]
- kernel void kernel_conv_2d<float>(
- constant ggml_metal_kargs_conv_2d & args,
- device const char * weights,
- device const char * src,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tgpg[[threadgroups_per_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]);
- template [[host_name("kernel_conv_2d_f16_f32")]]
- kernel void kernel_conv_2d<half>(
- constant ggml_metal_kargs_conv_2d & args,
- device const char * weights,
- device const char * src,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tgpg[[threadgroups_per_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]);
- typedef void (conv_transpose_1d_t)(
- constant ggml_metal_kargs_conv_transpose_1d & args,
- device const float * src0,
- device const float * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tgpg[[threadgroups_per_grid]]);
- template <typename T>
- kernel void kernel_conv_transpose_1d(
- constant ggml_metal_kargs_conv_transpose_1d & args,
- device const T * src0,
- device const float * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tgpg[[threadgroups_per_grid]]) {
- float v = 0.0f;
- for (int64_t c = 0; c < args.IC; c++) {
- const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];
- const int32_t input_offset = c * args.IL;
- for (int64_t i = 0; i < args.IL; i++) {
- if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {
- v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
- }
- }
- }
- device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
- dst_ptr[0] = v;
- }
- template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
- kernel void kernel_conv_transpose_1d<float>(
- constant ggml_metal_kargs_conv_transpose_1d & args,
- device const float * src0,
- device const float * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tgpg[[threadgroups_per_grid]]);
- template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
- kernel void kernel_conv_transpose_1d<half>(
- constant ggml_metal_kargs_conv_transpose_1d & args,
- device const half * src0,
- device const float * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tgpg[[threadgroups_per_grid]]);
- typedef void (conv_transpose_2d_t)(
- constant ggml_metal_kargs_conv_transpose_2d & args,
- device const float * src0,
- device const float * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tgpg[[threadgroups_per_grid]]);
- template <typename T>
- kernel void kernel_conv_transpose_2d(
- constant ggml_metal_kargs_conv_transpose_2d & args,
- device const T * src0,
- device const float * src1,
- device char * dst,
- threadgroup float * shared_sum [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t out_x = tgpig[0];
- const int64_t out_y = tgpig[1];
- const int64_t out_c = tgpig[2];
- const int64_t kw = tpitg[0];
- const int64_t kh = tpitg[1];
- float v = 0.0f;
- for (int64_t in_c = 0; in_c < args.IC; in_c++) {
- int64_t in_y = out_y - kh;
- if (in_y < 0 || in_y % args.s0) continue;
- in_y /= args.s0;
- if (in_y >= args.IH) continue;
- int64_t in_x = out_x - kw;
- if (in_x < 0 || in_x % args.s0) continue;
- in_x /= args.s0;
- if (in_x >= args.IW) continue;
- const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
- const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
- v += (float)src0[kernel_idx] * src1[input_idx];
- }
- const uint tid = tpitg.y * ntg.x + tpitg.x;
- shared_sum[tid] = v;
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (tid == 0) {
- float total = 0.0f;
- const uint num_threads = ntg.x * ntg.y;
- for (uint i = 0; i < num_threads; i++) {
- total += shared_sum[i];
- }
- device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
- dst_ptr[0] = total;
- }
- }
- template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
- kernel void kernel_conv_transpose_2d<float>(
- constant ggml_metal_kargs_conv_transpose_2d & args,
- device const float * src0,
- device const float * src1,
- device char * dst,
- threadgroup float * shared_sum [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]);
- template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
- kernel void kernel_conv_transpose_2d<half>(
- constant ggml_metal_kargs_conv_transpose_2d & args,
- device const half * src0,
- device const float * src1,
- device char * dst,
- threadgroup float * shared_sum [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]);
- kernel void kernel_upscale_f32(
- constant ggml_metal_kargs_upscale & args,
- device const char * src0,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i3 = tgpig.z;
- const int64_t i2 = tgpig.y;
- const int64_t i1 = tgpig.x;
- const int64_t i03 = i3/args.sf3;
- const int64_t i02 = i2/args.sf2;
- const int64_t i01 = i1/args.sf1;
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- const int64_t i00 = i0/args.sf0;
- device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
- device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- dst_ptr[0] = src0_ptr[0];
- }
- }
- kernel void kernel_pad_f32(
- constant ggml_metal_kargs_pad & args,
- device const char * src0,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i3 = tgpig.z;
- const int64_t i2 = tgpig.y;
- const int64_t i1 = tgpig.x;
- const int64_t i03 = i3;
- const int64_t i02 = i2;
- const int64_t i01 = i1;
- device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
- device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
- if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- if (i0 < args.ne00) {
- dst_ptr[i0] = src0_ptr[i0];
- } else {
- dst_ptr[i0] = 0.0f;
- }
- }
- return;
- }
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- dst_ptr[i0] = 0.0f;
- }
- }
- kernel void kernel_pad_reflect_1d_f32(
- constant ggml_metal_kargs_pad_reflect_1d & args,
- device const char * src0,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tgpg[[threadgroups_per_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i3 = tgpig.z;
- const int64_t i2 = tgpig.y;
- const int64_t i1 = tgpig.x;
- const int64_t i03 = i3;
- const int64_t i02 = i2;
- const int64_t i01 = i1;
- device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
- device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
- if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- if (i0 < args.p0) {
- dst_ptr[i0] = src0_ptr[args.p0 - i0];
- } else if (i0 < args.ne0 - args.p1) {
- dst_ptr[i0] = src0_ptr[i0 - args.p0];
- } else {
- dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
- }
- }
- }
- }
- kernel void kernel_arange_f32(
- constant ggml_metal_kargs_arange & args,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- device float * dst_ptr = (device float *) dst;
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- dst_ptr[i0] = args.start + args.step * i0;
- }
- }
- kernel void kernel_timestep_embedding_f32(
- constant ggml_metal_kargs_timestep_embedding & args,
- device const char * src0,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- int i = tgpig.x;
- device float * embed_data = (device float *)(dst + i*args.nb1);
- int half_ = args.dim / 2;
- for (int j = tpitg.x; j < half_; j += ntg.x) {
- float timestep = ((device float *)src0)[i];
- float freq = (float)exp(-log((float)args.max_period) * j / half_);
- float arg = timestep * freq;
- embed_data[j ] = cos(arg);
- embed_data[j + half_] = sin(arg);
- }
- if (args.dim % 2 != 0 && tpitg.x == 0) {
- embed_data[2 * half_] = 0.f;
- }
- }
- // bitonic sort implementation following the CUDA kernels as reference
- typedef void (argsort_t)(
- constant ggml_metal_kargs_argsort & args,
- device const char * src0,
- device int32_t * dst,
- threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]);
- template<ggml_sort_order order>
- kernel void kernel_argsort_f32_i32(
- constant ggml_metal_kargs_argsort & args,
- device const char * src0,
- device int32_t * dst,
- threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- // bitonic sort
- const int col = tpitg[0];
- const int ib = tgpig[0] / args.ne01;
- const int i00 = ib*ntg.x;
- const int i01 = tgpig[0] % args.ne01;
- const int i02 = tgpig[1];
- const int i03 = tgpig[2];
- device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
- // initialize indices
- shmem_i32[col] = i00 + col;
- threadgroup_barrier(mem_flags::mem_threadgroup);
- for (int k = 2; k <= ntg.x; k *= 2) {
- for (int j = k / 2; j > 0; j /= 2) {
- int ixj = col ^ j;
- if (ixj > col) {
- if ((col & k) == 0) {
- if (shmem_i32[col] >= args.ne00 ||
- (shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
- src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
- src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
- ) {
- SWAP(shmem_i32[col], shmem_i32[ixj]);
- }
- } else {
- if (shmem_i32[ixj] >= args.ne00 ||
- (shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
- src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
- src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
- ) {
- SWAP(shmem_i32[col], shmem_i32[ixj]);
- }
- }
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
- }
- const int64_t i0 = ib*args.top_k;
- // copy the result to dst without the padding
- if (i0 + col < args.ne0 && col < args.top_k) {
- dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
- dst[col] = shmem_i32[col];
- }
- }
- template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
- template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
- typedef void (argsort_merge_t)(
- constant ggml_metal_kargs_argsort_merge & args,
- device const char * src0,
- device const int32_t * tmp,
- device int32_t * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]);
- template<ggml_sort_order order>
- kernel void kernel_argsort_merge_f32_i32(
- constant ggml_metal_kargs_argsort_merge & args,
- device const char * src0,
- device const int32_t * tmp,
- device int32_t * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- const int im = tgpig[0] / args.ne01;
- const int i01 = tgpig[0] % args.ne01;
- const int i02 = tgpig[1];
- const int i03 = tgpig[2];
- const int start = im * (2 * args.len);
- const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
- const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
- const int total = len0 + len1;
- device const int32_t * tmp0 = tmp + start
- + i01*args.ne0
- + i02*args.ne0*args.ne01
- + i03*args.ne0*args.ne01*args.ne02;
- device const int32_t * tmp1 = tmp0 + args.len;
- dst += start
- + i01*args.top_k
- + i02*args.top_k*args.ne01
- + i03*args.top_k*args.ne01*args.ne02;
- device const float * src0_row = (device const float *)(src0
- + args.nb01*i01
- + args.nb02*i02
- + args.nb03*i03);
- if (total == 0) {
- return;
- }
- const int chunk = (total + ntg.x - 1) / ntg.x;
- const int k0 = tpitg.x * chunk;
- const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
- if (k0 >= args.top_k) {
- return;
- }
- if (k0 >= total) {
- return;
- }
- int low = k0 > len1 ? k0 - len1 : 0;
- int high = MIN(k0, len0);
- // binary-search partition (i, j) such that i + j = k
- while (low < high) {
- const int mid = (low + high) >> 1;
- const int32_t idx0 = tmp0[mid];
- const int32_t idx1 = tmp1[k0 - mid - 1];
- const float val0 = src0_row[idx0];
- const float val1 = src0_row[idx1];
- bool take_left;
- if (order == GGML_SORT_ORDER_ASC) {
- take_left = (val0 <= val1);
- } else {
- take_left = (val0 >= val1);
- }
- if (take_left) {
- low = mid + 1;
- } else {
- high = mid;
- }
- }
- int i = low;
- int j = k0 - i;
- // keep the merge fronts into registers
- int32_t idx0 = 0;
- float val0 = 0.0f;
- if (i < len0) {
- idx0 = tmp0[i];
- val0 = src0_row[idx0];
- }
- int32_t idx1 = 0;
- float val1 = 0.0f;
- if (j < len1) {
- idx1 = tmp1[j];
- val1 = src0_row[idx1];
- }
- for (int k = k0; k < k1; ++k) {
- int32_t out_idx;
- if (i >= len0) {
- while (k < k1) {
- dst[k++] = tmp1[j++];
- }
- break;
- } else if (j >= len1) {
- while (k < k1) {
- dst[k++] = tmp0[i++];
- }
- break;
- } else {
- bool take_left;
- if (order == GGML_SORT_ORDER_ASC) {
- take_left = (val0 <= val1);
- } else {
- take_left = (val0 >= val1);
- }
- if (take_left) {
- out_idx = idx0;
- ++i;
- if (i < len0) {
- idx0 = tmp0[i];
- val0 = src0_row[idx0];
- }
- } else {
- out_idx = idx1;
- ++j;
- if (j < len1) {
- idx1 = tmp1[j];
- val1 = src0_row[idx1];
- }
- }
- }
- dst[k] = out_idx;
- }
- }
- template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
- template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
- kernel void kernel_leaky_relu_f32(
- constant ggml_metal_kargs_leaky_relu & args,
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- const float x = src0[tpig];
- dst[tpig] = x > 0.0f ? x : x * args.slope;
- }
- kernel void kernel_leaky_relu_f32_4(
- constant ggml_metal_kargs_leaky_relu & args,
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- const float4 x = src0[tpig];
- dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
- }
- constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
- constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
- // pad the last chunk of C elements of k and v into a an extra pad buffer
- kernel void kernel_flash_attn_ext_pad(
- constant ggml_metal_kargs_flash_attn_ext_pad & args,
- device const char * k,
- device const char * v,
- device const char * mask,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiitg[[thread_index_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- const int32_t C = FC_flash_attn_ext_pad_ncpsg;
- device char * k_pad = dst;
- device char * v_pad = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3;
- device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3;
- const int32_t icp = args.ne11 % C;
- const int32_t ic0 = args.ne11 - icp;
- const int32_t i1 = tgpig[0];
- const int32_t i2 = tgpig[1];
- const int32_t i3 = tgpig[2];
- if (i2 < args.ne_12_2 && i3 < args.ne_12_3) {
- device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3;
- device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3;
- device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3;
- device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;
- if (i1 >= icp) {
- // here it is not important the exact value that will be used as we rely on masking out the scores in the attention
- for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
- k_dst[i] = 0;
- }
- for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
- v_dst[i] = 0;
- }
- } else {
- for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
- k_dst[i] = k_src[i];
- }
- for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
- v_dst[i] = v_src[i];
- }
- }
- }
- if (FC_flash_attn_ext_pad_has_mask) {
- if (i2 < args.ne32 && i3 < args.ne33) {
- for (int ib = i1; ib < args.ne31; ib += C) {
- device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0;
- device half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3;
- for (int i = tiitg; i < C; i += ntg.x) {
- if (i >= icp) {
- mask_dst[i] = -MAXHALF;
- } else {
- mask_dst[i] = mask_src[i];
- }
- }
- }
- }
- }
- }
- constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]];
- constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]];
- // scan the blocks of the mask that are not masked
- // 0 - masked (i.e. full of -INF, skip)
- // 1 - not masked (i.e. at least one element of the mask is not -INF)
- kernel void kernel_flash_attn_ext_blk(
- constant ggml_metal_kargs_flash_attn_ext_blk & args,
- device const char * mask,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]]) {
- // block size C x Q
- const int32_t Q = FC_flash_attn_ext_blk_nqptg;
- const int32_t C = FC_flash_attn_ext_blk_ncpsg;
- constexpr short NW = N_SIMDWIDTH;
- const int32_t i3 = tgpig[2]/args.ne32;
- const int32_t i2 = tgpig[2]%args.ne32;
- const int32_t i1 = tgpig[1];
- const int32_t i0 = tgpig[0];
- char res = i0*C + C > args.ne30 ? 1 : 0;
- device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
- // fast route
- if (res == 0) {
- if (simd_max(*mask_src) > -MAXHALF/2) {
- res = 1;
- }
- }
- // detailed check of the elements of the block
- if ((C > NW || Q > 1) && res == 0) {
- half m = -MAXHALF;
- FOR_UNROLL (short j = 0; j < Q; ++j) {
- FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
- m = max(m, mask_src[ii*NW]);
- }
- mask_src += args.nb31/2;
- }
- if (simd_max(m) > -MAXHALF/2) {
- res = 1;
- }
- }
- const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
- const int32_t nblk0 = ((args.ne30 + C - 1)/C);
- if (tiisg == 0) {
- dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res;
- }
- }
- constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
- constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
- constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
- constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
- constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];
- constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
- //constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
- //constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
- //constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]];
- constant int32_t FC_flash_attn_ext_ns10 [[function_constant(FC_FLASH_ATTN_EXT + 20)]];
- constant int32_t FC_flash_attn_ext_ns20 [[function_constant(FC_FLASH_ATTN_EXT + 21)]];
- constant int32_t FC_flash_attn_ext_nsg [[function_constant(FC_FLASH_ATTN_EXT + 22)]];
- // ref: https://arxiv.org/pdf/2307.08691.pdf
- template<
- typename q_t, // query types in shared memory
- typename q4_t,
- typename q8x8_t,
- typename k_t, // key types in shared memory
- typename k4x4_t,
- typename k8x8_t,
- typename v_t, // value types in shared memory
- typename v4x4_t,
- typename v8x8_t,
- typename qk_t, // Q*K types
- typename qk8x8_t,
- typename s_t, // soft-max types
- typename s2_t,
- typename s8x8_t,
- typename o_t, // attention accumulation types
- typename o4_t,
- typename o8x8_t,
- typename kd4x4_t, // key type in device memory
- short nl_k,
- void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
- typename vd4x4_t, // value type in device memory
- short nl_v,
- void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
- short DK, // K head size
- short DV, // V head size
- short Q, // queries per threadgroup
- short C, // cache items per threadgroup
- short NSG> // number of simd groups
- void kernel_flash_attn_ext_impl(
- constant ggml_metal_kargs_flash_attn_ext & args,
- device const char * q,
- device const char * k,
- device const char * v,
- device const char * mask,
- device const char * sinks,
- device const char * pad,
- device const char * blk,
- device char * dst,
- threadgroup half * shmem_f16,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const ushort iq3 = tgpig[2];
- const ushort iq2 = tgpig[1];
- const ushort iq1 = tgpig[0]*Q;
- #define NS10 (FC_flash_attn_ext_ns10)
- #define NS20 (FC_flash_attn_ext_ns20)
- // note: I had some concerns that using this instead of the ugly macros above was affecting performance
- // need to re-check carefully and if no regressions are observerd - remove the macros
- // the concerns is that maybe using const variables requires extra registers? but not sure if the compiler
- // is clever enough to avoid this. unfortunately, using constexpr is not possible with FC
- //const short NS10 = FC_flash_attn_ext_ns10;
- //const short NS20 = FC_flash_attn_ext_ns20;
- constexpr short KV = 8;
- constexpr short DK4 = DK/4;
- constexpr short DK8 = DK/8;
- constexpr short DK16 = DK/16;
- constexpr short DV4 = DV/4;
- //constexpr short DV8 = DV/8;
- constexpr short DV16 = DV/16;
- constexpr short PV = PAD2(DV, 64);
- constexpr short PV4 = PV/4;
- constexpr short PV8 = PV/8;
- //constexpr short PV16 = PV/16;
- constexpr short NW = N_SIMDWIDTH;
- constexpr short NQ = Q/NSG;
- constexpr short SH = 2*C; // shared memory per simdgroup (s_t == float)
- constexpr short TS = 2*SH;
- constexpr short T = DK + 2*PV; // shared memory size per query in (half)
- threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*T); // holds the query data
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*T); // same as above but in q4_t
- threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*T + Q*DK); // the result for all queries in 8x8 matrices (the O matrix from the paper)
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*T + Q*DK);
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + Q*T); // scratch buffer for attention, mask and diagonal matrix
- threadgroup s2_t * ss2 = (threadgroup s2_t *) (shmem_f16 + Q*T); // same as above but in s2_t
- threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load K in shared memory
- threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in k4x4_t
- threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load V in shared memory
- threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in v4x4_t
- // mask storage in shared mem
- threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C);
- // per-query mask pointers
- device const half2 * pm2[NQ];
- FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
- const short j = jj*NSG + sgitg;
- pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
- }
- {
- const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
- const int32_t nblk0 = ((args.ne11 + C - 1)/C);
- blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0;
- }
- {
- q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
- const short ikv2 = iq2/(args.ne02/args.ne_12_2);
- const short ikv3 = iq3/(args.ne03/args.ne_12_3);
- k += ikv2*args.nb12 + ikv3*args.nb13;
- v += ikv2*args.nb22 + ikv3*args.nb23;
- }
- // load heads from Q to shared memory
- FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
- const short j = jj*NSG + sgitg;
- device const float4 * q4 = (device const float4 *) ((device const char *) q + j*args.nb01);
- for (short i = tiisg; i < DK4; i += NW) {
- if (iq1 + j < args.ne01) {
- sq4[j*DK4 + i] = (q4_t) q4[i];
- } else {
- sq4[j*DK4 + i] = 0;
- }
- }
- }
- // zero out
- FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
- const short j = jj*NSG + sgitg;
- for (short i = tiisg; i < DV4; i += NW) {
- so4[j*PV4 + i] = 0;
- }
- for (short i = tiisg; i < SH; i += NW) {
- ss[j*SH + i] = 0.0f;
- }
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- float S[NQ] = { [0 ... NQ-1] = 0.0f };
- {
- float M[NQ] = { [0 ... NQ-1] = -FLT_MAX/2 };
- float slope = 1.0f;
- // ALiBi
- if (FC_flash_attn_ext_has_bias) {
- const short h = iq2;
- const float base = h < args.n_head_log2 ? args.m0 : args.m1;
- const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
- slope = pow(base, exph);
- }
- // loop over the KV cache
- // each simdgroup handles blocks of Q rows and C columns
- for (int ic0 = 0; ; ++ic0) {
- int ic = ic0*C;
- if (ic >= args.ne11) {
- break;
- }
- // the last partial chunk uses the pad buffer as source
- if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) {
- k = pad;
- v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
- mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
- const short ikv2 = iq2/(args.ne02/args.ne_12_2);
- const short ikv3 = iq3/(args.ne03/args.ne_12_3);
- k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
- v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
- if (!FC_flash_attn_ext_has_mask) {
- threadgroup half * sm = (threadgroup half *) (sm2);
- FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
- const short j = jj*NSG + sgitg;
- for (short i = tiisg; i < C; i += NW) {
- if (ic + i >= args.ne11) {
- sm[2*j*SH + i] = -MAXHALF;
- }
- }
- }
- } else {
- FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
- const short j = jj*NSG + sgitg;
- pm2[jj] = (device const half2 *) ((device const half *) mask +
- (iq1 + j)*C +
- (iq2%args.ne32)*(C*args.ne31) +
- (iq3%args.ne33)*(C*args.ne31*args.ne32));
- }
- }
- ic = 0;
- }
- // read the mask into shared mem
- if (FC_flash_attn_ext_has_mask) {
- if (blk[ic0] == 0) {
- FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
- pm2[jj] += NW;
- }
- continue;
- }
- FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
- const short j = jj*NSG + sgitg;
- if (FC_flash_attn_ext_bc_mask) {
- sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
- } else {
- sm2[j*SH + tiisg] = pm2[jj][tiisg];
- }
- pm2[jj] += NW;
- }
- #if 0
- // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks
- threadgroup_barrier(mem_flags::mem_threadgroup);
- // used to detect blocks full of -INF
- // skip only when the entire threadgroup is masked
- half2 smax2(-MAXHALF/2, -MAXHALF/2);
- FOR_UNROLL (short j = 0; j < Q; ++j) {
- smax2 = max(smax2, sm2[j*SH + tiisg]);
- }
- smax2 = simd_max(smax2);
- if (max(smax2[0], smax2[1]) <= -MAXHALF/2) {
- // this barrier is important
- threadgroup_barrier(mem_flags::mem_threadgroup);
- continue;
- }
- #endif
- }
- // Q*K^T
- // this is compile-time check, so it does not have runtime overhead
- if (is_same<kd4x4_t, k4x4_t>::value) {
- // we can read directly from global memory
- device const k_t * pk = (device const k_t *) (k + ic*args.nb11);
- threadgroup const q_t * pq = sq;
- threadgroup s_t * ps = ss;
- pk += sgitg*(8*NS10);
- ps += sgitg*(8*1);
- static_assert((C/8) % NSG == 0, "");
- constexpr short NC = (C/8)/NSG;
- // note: do not unroll for large heads
- #pragma unroll (DK <= 64 ? NC : 1)
- for (short cc = 0; cc < NC; ++cc) {
- qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
- if (DK % 16 != 0) {
- k8x8_t mk;
- q8x8_t mq;
- FOR_UNROLL (short i = 0; i < DK8; ++i) {
- simdgroup_barrier(mem_flags::mem_none);
- simdgroup_load(mk, pk + 8*i, NS10, 0, true);
- simdgroup_load(mq, pq + 8*i, DK);
- simdgroup_barrier(mem_flags::mem_none);
- simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
- }
- } else {
- k8x8_t mk[2];
- q8x8_t mq[2];
- FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
- simdgroup_barrier(mem_flags::mem_none);
- simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
- simdgroup_load(mq[1], pq + 1*8 + 16*i, DK);
- simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true);
- simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true);
- simdgroup_barrier(mem_flags::mem_none);
- simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);
- simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);
- }
- }
- simdgroup_store(mqk, ps, SH, 0, false);
- pk += 8*(NSG*NS10);
- ps += 8*(NSG);
- }
- } else {
- // TODO: this is the quantized K cache branch - not optimized yet
- for (short ccc = 0; ccc < (C/8)/NSG; ++ccc) {
- const short cc = ccc*NSG + sgitg;
- const short tx = tiisg%4;
- const short ty = tiisg/4;
- qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
- for (short ii = 0; ii < DK16; ii += 4) {
- device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11));
- if (DK16%4 == 0) {
- // the head is evenly divisible by 4*16 = 64, so no need for bound checks
- {
- k4x4_t tmp;
- deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
- sk4x4[4*ty + tx] = tmp;
- }
- simdgroup_barrier(mem_flags::mem_threadgroup);
- FOR_UNROLL (short k = 0; k < 4; ++k) {
- k8x8_t mk;
- q8x8_t mq;
- simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
- simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
- simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
- simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
- simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
- simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
- }
- } else {
- if (ii + tx < DK16) {
- k4x4_t tmp;
- deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
- sk4x4[4*ty + tx] = tmp;
- }
- simdgroup_barrier(mem_flags::mem_threadgroup);
- for (short k = 0; k < 4 && ii + k < DK16; ++k) {
- k8x8_t mk;
- q8x8_t mq;
- simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
- simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
- simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
- simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
- simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
- simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
- }
- }
- }
- simdgroup_store(mqk, ss + 8*cc, SH, 0, false);
- }
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- // online softmax
- FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
- const short j = jj*NSG + sgitg;
- const float m = M[jj];
- // scale and apply the logitcap / mask
- float2 s2 = ss2[j*SH/2 + tiisg]*args.scale;
- if (FC_flash_attn_ext_has_scap) {
- s2 = args.logit_softcap*precise::tanh(s2);
- }
- // mqk = mqk + slope*mask
- if (FC_flash_attn_ext_has_bias) {
- s2 += s2_t(sm2[j*SH + tiisg])*slope;
- } else {
- s2 += s2_t(sm2[j*SH + tiisg]);
- }
- M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
- const float ms = exp(m - M[jj]);
- const float2 vs2 = exp(s2 - M[jj]);
- S[jj] = S[jj]*ms + simd_sum(vs2[0] + vs2[1]);
- // the P matrix from the paper (Q rows, C columns)
- ss2[j*SH/2 + tiisg] = vs2;
- if (DV4 % NW == 0) {
- FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
- const short i = ii*NW + tiisg;
- so4[j*PV4 + i] *= ms;
- }
- } else {
- for (short i = tiisg; i < DV4; i += NW) {
- so4[j*PV4 + i] *= ms;
- }
- }
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- // O = O + (Q*K^T)*V
- {
- // we can read directly from global memory
- if (is_same<vd4x4_t, v4x4_t>::value) {
- static_assert(PV8 % NSG == 0, "");
- constexpr short NO = PV8/NSG;
- o8x8_t lo[NO];
- {
- auto sot = so + 8*sgitg;
- FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
- simdgroup_load(lo[ii], sot, PV, 0, false);
- sot += 8*NSG;
- }
- }
- {
- device const v_t * pv = (device const v_t *) (v + ic*args.nb21);
- pv += 8*sgitg;
- if (DV <= 64) {
- FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
- s8x8_t vs;
- simdgroup_load(vs, ss + 8*cc, SH, 0, false);
- FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
- v8x8_t mv[2];
- simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false);
- simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false);
- simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]);
- simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]);
- }
- pv += 8*NS20;
- }
- } else {
- FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
- s8x8_t vs[2];
- simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
- simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false);
- FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
- v8x8_t mv[4];
- simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
- simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
- simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
- simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
- simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]);
- simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]);
- simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]);
- simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]);
- }
- pv += 2*8*NS20;
- }
- }
- }
- {
- auto sot = so + 8*sgitg;
- FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
- simdgroup_store(lo[ii], sot, PV, 0, false);
- sot += 8*NSG;
- }
- }
- } else {
- // TODO: this is the quantized V cache branch - not optimized yet
- const short tx = tiisg%4;
- const short ty = tiisg/4;
- for (short cc = 0; cc < C/8; ++cc) {
- s8x8_t vs;
- simdgroup_load(vs, ss + 8*cc, SH, 0, false);
- for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
- device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21));
- if (DV16%4 == 0) {
- // no need for bound checks
- {
- v4x4_t tmp;
- deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
- sv4x4[4*ty + tx] = tmp;
- }
- simdgroup_barrier(mem_flags::mem_threadgroup);
- FOR_UNROLL (short k = 0; k < 4; ++k) {
- v8x8_t mv[2];
- o8x8_t lo[2];
- simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);
- simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);
- simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
- simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
- simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);
- simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);
- simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
- simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
- }
- } else {
- if (ii + tx < DV16) {
- v4x4_t tmp;
- deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
- sv4x4[4*ty + tx] = tmp;
- }
- simdgroup_barrier(mem_flags::mem_threadgroup);
- for (short k = 0; k < 4 && ii + k < DV16; ++k) {
- v8x8_t mv[2];
- o8x8_t lo[2];
- simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);
- simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);
- simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
- simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
- simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);
- simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);
- simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
- simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
- }
- }
- }
- }
- }
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
- if (FC_flash_attn_ext_has_sinks) {
- FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
- const short j = jj*NSG + sgitg;
- const float m = M[jj];
- const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
- M[jj] = simd_max(max(M[jj], s));
- const float ms = exp(m - M[jj]);
- const float vs = exp(s - M[jj]);
- S[jj] = S[jj]*ms + simd_sum(vs);
- for (short i = tiisg; i < DV4; i += NW) {
- so4[j*PV4 + i] *= ms;
- }
- }
- }
- }
- // store to global memory
- for (short jj = 0; jj < NQ; ++jj) {
- const short j = jj*NSG + sgitg;
- if (iq1 + j >= args.ne01) {
- break;
- }
- device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
- const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj];
- if (DV4 % NW == 0) {
- FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
- const short i = ii*NW + tiisg;
- dst4[i] = (float4) so4[j*PV4 + i]*scale;
- }
- } else {
- for (short i = tiisg; i < DV4; i += NW) {
- dst4[i] = (float4) so4[j*PV4 + i]*scale;
- }
- }
- }
- #undef NS10
- #undef NS20
- }
- template<
- typename q_t, // query types in shared memory
- typename q4_t,
- typename q8x8_t,
- typename k_t, // key types in shared memory
- typename k4x4_t,
- typename k8x8_t,
- typename v_t, // value types in shared memory
- typename v4x4_t,
- typename v8x8_t,
- typename qk_t, // Q*K types
- typename qk8x8_t,
- typename s_t, // soft-max types
- typename s2_t,
- typename s8x8_t,
- typename o_t, // attention accumulation types
- typename o4_t,
- typename o8x8_t,
- typename kd4x4_t, // key type in device memory
- short nl_k,
- void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
- typename vd4x4_t, // value type in device memory
- short nl_v,
- void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
- short DK, // K head size
- short DV, // V head size
- short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
- short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
- kernel void kernel_flash_attn_ext(
- constant ggml_metal_kargs_flash_attn_ext & args,
- device const char * q,
- device const char * k,
- device const char * v,
- device const char * mask,
- device const char * sinks,
- device const char * pad,
- device const char * blk,
- device char * dst,
- threadgroup half * shmem_f16 [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- #define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
- #define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg
- switch (FC_flash_attn_ext_nsg) {
- // note: disabled cases to reduce library load time
- //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
- //case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
- case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
- }
- #undef FWD_TMPL
- #undef FWD_ARGS
- }
- // TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as
- // template to be able to explore different combinations
- //
- #define FA_TYPES \
- half, half4, simdgroup_half8x8, \
- half, half4x4, simdgroup_half8x8, \
- half, half4x4, simdgroup_half8x8, \
- float, simdgroup_float8x8, \
- float, float2, simdgroup_float8x8, \
- float, float4, simdgroup_float8x8
- //half, half4, simdgroup_half8x8
- #define FA_TYPES_BF \
- bfloat, bfloat4, simdgroup_bfloat8x8, \
- bfloat, bfloat4x4, simdgroup_bfloat8x8, \
- bfloat, bfloat4x4, simdgroup_bfloat8x8, \
- float, simdgroup_float8x8, \
- float, float2, simdgroup_float8x8, \
- half, half4, simdgroup_half8x8
- //float, float4, simdgroup_float8x8
- #define FA_TYPES_F32 \
- half, half4, simdgroup_half8x8, \
- float, float4x4, simdgroup_float8x8, \
- float, float4x4, simdgroup_float8x8, \
- float, simdgroup_float8x8, \
- float, float2, simdgroup_float8x8, \
- float, float4, simdgroup_float8x8
- //half, half4, simdgroup_half8x8
- typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
- template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
- template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
- template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
- template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 72, 72>;
- template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
- template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
- template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
- template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>;
- template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
- template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
- template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
- template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
- template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
- template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
- template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
- template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 72, 72>;
- template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
- template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
- template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
- template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128, 128>;
- template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
- template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
- template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
- template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
- template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
- template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
- template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 72, 72>;
- template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
- template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
- template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
- template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
- template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
- template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
- template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
- template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
- #endif
- template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
- template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
- template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
- template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72, 72>;
- template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
- template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
- template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
- template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
- template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
- template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
- template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
- template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
- template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
- template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
- template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
- template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72, 72>;
- template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
- template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
- template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
- template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
- template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
- template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
- template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
- template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
- template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
- template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
- template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
- template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72, 72>;
- template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
- template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
- template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
- template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
- template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
- template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
- template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
- template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
- template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
- template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
- template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
- template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72, 72>;
- template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
- template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
- template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
- template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
- template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
- template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
- template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
- template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
- template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
- template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
- template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
- template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72, 72>;
- template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
- template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
- template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
- template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
- template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
- template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
- template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
- template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
- #undef FA_TYPES
- #undef FA_TYPES_BF
- #undef FA_TYPES_F32
- constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]];
- constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
- constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];
- constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];
- constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]];
- //constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];
- //constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];
- //constant float FC_flash_attn_ext_vec_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 12)]];
- constant int32_t FC_flash_attn_ext_vec_ns10 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 20)]];
- constant int32_t FC_flash_attn_ext_vec_ns20 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 21)]];
- constant int32_t FC_flash_attn_ext_vec_nsg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 22)]];
- constant int32_t FC_flash_attn_ext_vec_nwg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 23)]];
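- //
- // the host selects these values when specializing the pipeline; a minimal hypothetical
- // host-side sketch (not part of this file) of setting the nsg constant could look like:
- //
- //   MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
- //   const int32_t nsg = 2;
- //   [cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 22];
- //   id<MTLFunction> fn = [lib newFunctionWithName:@"kernel_flash_attn_ext_vec_f16_dk128_dv128"
- //                                  constantValues:cv error:&err];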
- template<
- typename q4_t, // query type in shared memory
- typename k4_t, // key type in shared memory
- typename v4_t, // value type in shared memory
- typename qk_t, // Q*K type
- typename s_t, // soft-max type
- typename s4_t, // 4-wide soft-max type
- typename o4_t, // attention accumulation type
- typename kd4_t, // key type in device memory
- short nl_k,
- void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
- typename vd4_t, // value type in device memory
- short nl_v,
- void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
- short DK, // K head size
- short DV, // V head size
- short NE, // head elements per thread
- short Q, // queries per threadgroup
- short C, // cache items per threadgroup
- short NSG> // number of simd groups
- void kernel_flash_attn_ext_vec_impl(
- constant ggml_metal_kargs_flash_attn_ext_vec & args,
- device const char * q,
- device const char * k,
- device const char * v,
- device const char * mask,
- device const char * sinks,
- device const char * pad,
- device char * dst,
- threadgroup half * shmem_f16 [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- static_assert(DK % 32 == 0, "DK must be divisible by 32");
- static_assert(DV % 32 == 0, "DV must be divisible by 32");
- #define NWG (FC_flash_attn_ext_vec_nwg)
- #define NS10 (FC_flash_attn_ext_vec_ns10)
- #define NS20 (FC_flash_attn_ext_vec_ns20)
- const short iwg = tgpig[2]%NWG;
- const ushort iq3 = tgpig[2]/NWG;
- const ushort iq2 = tgpig[1];
- const ushort iq1 = tgpig[0];
- constexpr short DK4 = DK/4;
- constexpr short DV4 = DV/4;
- constexpr short PK = PAD2(DK, 128);
- constexpr short PK4 = PK/4;
- constexpr short PV = PAD2(DV, 128);
- constexpr short PV4 = PV/4;
- constexpr short NW = N_SIMDWIDTH;
- constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup workloads
- constexpr short SH = 4*C; // shared memory per simdgroup
- static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
- static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
- const short T = PK + NSG*SH; // shared memory size per query, in half elements
- //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t
- threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results
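- // layout of shmem_f16 (in half elements) for a hypothetical instantiation with
- // DK = DV = 128, C = 32, NSG = 2, assuming PAD2 rounds up to a multiple of 128:
- //   PK = PV = 128, SH = 4*C = 128, T = PK + NSG*SH = 384
- //   [0, Q*PK)             : sq4 - the query data
- //   [Q*PK, Q*PK + NSG*SH) : ss/ss4 (+ sm at offset 2*C) - per-simdgroup attention/mask scratch
- //   [Q*T, Q*T + NSG*2*PV) : so4 - per-simdgroup output accumulators (2*PV halves each)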
- // store the result for all queries in shared memory (the O matrix from the paper)
- so4 += tiisg;
- {
- q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
- const short ikv2 = iq2/(args.ne02/args.ne_12_2);
- const short ikv3 = iq3/(args.ne03/args.ne_12_3);
- k += ikv2*args.nb12 + ikv3*args.nb13;
- v += ikv2*args.nb22 + ikv3*args.nb23;
- }
- // load heads from Q to shared memory
- device const float4 * q4 = (device const float4 *) ((device const char *) q);
- for (short i = tiisg; i < PK4; i += NW) {
- if (iq1 < args.ne01 && i < DK4) {
- sq4[i] = (q4_t) q4[i];
- } else {
- sq4[i] = (q4_t) 0.0f;
- }
- }
- // zero out the output accumulator so4
- for (short i = 0; i < DV4/NL; ++i) {
- so4[i*NL] = (o4_t) 0.0f;
- }
- // zero out the per-simdgroup scratch memory SH
- for (short i = tiisg; i < SH/4; i += NW) {
- ss4[i] = (s4_t) 0.0f;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- {
- float S = 0.0f;
- float M = -FLT_MAX/2;
- // thread indices inside the simdgroup
- const short tx = tiisg%NL;
- const short ty = tiisg/NL;
- // pointer to the mask
- device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
- float slope = 1.0f;
- // ALiBi
- if (FC_flash_attn_ext_vec_has_bias) {
- const short h = iq2;
- const float base = h < args.n_head_log2 ? args.m0 : args.m1;
- const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
- slope = pow(base, exph);
- }
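- // i.e. slope = m0^(h+1) for the first n_head_log2 heads and m1^(2*(h - n_head_log2) + 1)
- // for the remaining heads - the per-head geometric slopes from the ALiBi paper
- // (typically m0 = 2^(-8/n_head_log2) and m1 = 2^(-4/n_head_log2))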
- // loop over the KV cache
- // each simdgroup handles blocks of Q rows and C columns
- for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) {
- int ic = ic0*C;
- if (ic >= args.ne11) {
- break;
- }
- // the last partial chunk uses the pad buffer as source
- if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
- k = pad;
- v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
- mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
- const short ikv2 = iq2/(args.ne02/args.ne_12_2);
- const short ikv3 = iq3/(args.ne03/args.ne_12_3);
- k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
- v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
- if (!FC_flash_attn_ext_vec_has_mask) {
- if (ic + tiisg >= args.ne11) {
- sm[tiisg] = -MAXHALF;
- }
- } else {
- pm = (device const half *) (mask) +
- iq1*C +
- (iq2%args.ne32)*(C*args.ne31) +
- (iq3%args.ne33)*(C*args.ne31*args.ne32);
- }
- ic = 0;
- }
- if (FC_flash_attn_ext_vec_has_mask) {
- sm[tiisg] = pm[ic + tiisg];
- }
- // skip -INF blocks
- if (simd_max(sm[tiisg]) == -INFINITY) {
- continue;
- }
- // Q*K^T
- {
- device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11);
- threadgroup const q4_t * pq4 = sq4;
- pk4 += ty*NS10/4 + tx;
- pq4 += tx;
- qk_t mqk[C/NE] = { [0 ... C/NE - 1] = 0.0f };
- // each simdgroup processes 1 query and NE (NW/NL) cache elements
- FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
- if (is_same<kd4_t, k4_t>::value) {
- FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) {
- mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]);
- }
- } else {
- device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11));
- k4_t mk;
- FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) {
- const short i = ii*NL + tx;
- deq_k_t4(pk + i/nl_k, i%nl_k, mk);
- mqk[cc] += dot((float4) mk, (float4) sq4[i]);
- }
- }
- if (NE == 1) {
- mqk[cc] = simd_sum(mqk[cc]);
- } else {
- // simdgroup reduce (NE = 4)
- // [ 0 .. 7] -> [ 0]
- // [ 8 .. 15] -> [ 8]
- // [16 .. 23] -> [16]
- // [24 .. 31] -> [24]
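- // for NE = 2 the same chain instead reduces within groups of NL = 16 lanes:
- // [ 0 .. 15] -> [ 0]
- // [16 .. 31] -> [16]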
- if (NE <= 1) {
- mqk[cc] += simd_shuffle_down(mqk[cc], 16);
- }
- if (NE <= 2) {
- mqk[cc] += simd_shuffle_down(mqk[cc], 8);
- }
- if (NE <= 4) {
- mqk[cc] += simd_shuffle_down(mqk[cc], 4);
- }
- if (NE <= 8) {
- mqk[cc] += simd_shuffle_down(mqk[cc], 2);
- }
- if (NE <= 16) {
- mqk[cc] += simd_shuffle_down(mqk[cc], 1);
- }
- // broadcast
- mqk[cc] = simd_shuffle(mqk[cc], NL*ty);
- }
- }
- if (FC_flash_attn_ext_vec_has_mask &&
- !FC_flash_attn_ext_vec_has_scap &&
- !FC_flash_attn_ext_vec_has_bias) {
- ss[NE*tx + ty] = fma(mqk[tx], args.scale, (qk_t) sm[NE*tx + ty]);
- } else {
- mqk[tx] *= args.scale;
- if (FC_flash_attn_ext_vec_has_scap) {
- mqk[tx] = args.logit_softcap*precise::tanh(mqk[tx]);
- }
- if (FC_flash_attn_ext_vec_has_bias) {
- mqk[tx] += (qk_t) sm[NE*tx + ty]*slope;
- } else {
- mqk[tx] += (qk_t) sm[NE*tx + ty];
- }
- ss[NE*tx + ty] = mqk[tx];
- }
- }
- simdgroup_barrier(mem_flags::mem_threadgroup);
- // online softmax
- {
- const float m = M;
- const float s = ss[tiisg];
- M = simd_max(max(M, s));
- const float ms = exp(m - M);
- const float vs = exp(s - M);
- S = S*ms + simd_sum(vs);
- // the P matrix from the paper (Q rows, C columns)
- ss[tiisg] = vs;
- // O = diag(ms)*O
- if ((DV4/NL % NW == 0) || ty == 0) {
- FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
- so4[ii*NL] *= ms;
- }
- }
- }
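- // in math form, with s_c the scores of the current chunk:
- //   M' = max(M, max_c s_c)
- //   S' = S*exp(M - M') + sum_c exp(s_c - M')
- //   O' = exp(M - M')*O (the P*V contribution for this chunk is accumulated below)
- // so that S always equals the softmax denominator over all scores seen so far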
- simdgroup_barrier(mem_flags::mem_threadgroup);
- // O = O + (Q*K^T)*V
- {
- o4_t lo[DV4/NL];
- FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
- lo[ii] = 0.0f;
- }
- if (is_same<vd4_t, v4_t>::value) {
- device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21);
- pv4 += ty*NS20/4 + tx;
- const auto sst = ss + ty;
- FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
- FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
- lo[ii] += o4_t(float4(pv4[cc*NE*NS20/4 + ii*NL])*float4(sst[cc*NE]));
- }
- }
- } else {
- FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
- device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21));
- FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
- const short i = ii*NL + tx;
- v4_t mv;
- deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
- lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty]));
- }
- }
- }
- FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
- if (NE > 1) {
- lo[ii][0] += simd_shuffle_down(lo[ii][0], 16);
- lo[ii][1] += simd_shuffle_down(lo[ii][1], 16);
- lo[ii][2] += simd_shuffle_down(lo[ii][2], 16);
- lo[ii][3] += simd_shuffle_down(lo[ii][3], 16);
- }
- if (NE > 2) {
- lo[ii][0] += simd_shuffle_down(lo[ii][0], 8);
- lo[ii][1] += simd_shuffle_down(lo[ii][1], 8);
- lo[ii][2] += simd_shuffle_down(lo[ii][2], 8);
- lo[ii][3] += simd_shuffle_down(lo[ii][3], 8);
- }
- if (NE > 4) {
- lo[ii][0] += simd_shuffle_down(lo[ii][0], 4);
- lo[ii][1] += simd_shuffle_down(lo[ii][1], 4);
- lo[ii][2] += simd_shuffle_down(lo[ii][2], 4);
- lo[ii][3] += simd_shuffle_down(lo[ii][3], 4);
- }
- if (NE > 8) {
- lo[ii][0] += simd_shuffle_down(lo[ii][0], 2);
- lo[ii][1] += simd_shuffle_down(lo[ii][1], 2);
- lo[ii][2] += simd_shuffle_down(lo[ii][2], 2);
- lo[ii][3] += simd_shuffle_down(lo[ii][3], 2);
- }
- if (NE > 16) {
- lo[ii][0] += simd_shuffle_down(lo[ii][0], 1);
- lo[ii][1] += simd_shuffle_down(lo[ii][1], 1);
- lo[ii][2] += simd_shuffle_down(lo[ii][2], 1);
- lo[ii][3] += simd_shuffle_down(lo[ii][3], 1);
- }
- }
- if ((DV4/NL % NW == 0) || ty == 0) {
- FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
- so4[ii*NL] += lo[ii];
- }
- }
- }
- }
- if (FC_flash_attn_ext_vec_has_sinks && sgitg == 0 && iwg == 0) {
- const float m = M;
- const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
- M = simd_max(max(M, s));
- const float ms = exp(m - M);
- const float vs = exp(s - M);
- S = S*ms + simd_sum(vs);
- if ((DV4/NL % NW == 0) || ty == 0) {
- FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
- so4[ii*NL] *= ms;
- }
- }
- }
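- // the attention sink is one extra per-head logit folded into the softmax state:
- // it updates (S, M) exactly like a score, but contributes nothing to O beyond the rescale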
- // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
- if (tiisg == 0) {
- ss[0] = (s_t) S;
- ss[1] = (s_t) M;
- }
- }
- so4 -= tiisg;
- threadgroup_barrier(mem_flags::mem_threadgroup);
- // parallel reduce
- for (short r = NSG/2; r > 0; r >>= 1) {
- if (sgitg < r) {
- const float S0 = ss[ 0];
- const float S1 = ss[r*(SH/2) + 0];
- const float M0 = ss[ 1];
- const float M1 = ss[r*(SH/2) + 1];
- const float M = max(M0, M1);
- const float ms0 = exp(M0 - M);
- const float ms1 = exp(M1 - M);
- const float S = S0*ms0 + S1*ms1;
- if (tiisg == 0) {
- ss[0] = S;
- ss[1] = M;
- }
- // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
- for (short i = tiisg; i < DV4; i += NW) {
- so4[i] = so4[i]*ms0 + so4[i + r*PV4]*ms1;
- }
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
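- // each step merges two partial states (S0, M0, O_0) and (S1, M1, O_1):
- //   M = max(M0, M1)
- //   S = S0*exp(M0 - M) + S1*exp(M1 - M)
- //   O = O_0*exp(M0 - M) + O_1*exp(M1 - M)
- // after log2(NSG) steps simdgroup 0 holds the combined state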
- // final rescale with 1/S and store to global memory
- if (sgitg == 0) {
- const int64_t nrows = args.ne3*args.ne2*args.ne1;
- const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;
- device float4 * dst4 = (device float4 *) dst;
- device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results
- const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f;
- // interleave the workgroup data
- for (short i = tiisg; i < DV4; i += NW) {
- dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[i]*S;
- }
- // store S and M
- if (NWG > 1) {
- if (tiisg == 0) {
- dst1[rid*(2*NWG) + 2*iwg + 0] = ss[0];
- dst1[rid*(2*NWG) + 2*iwg + 1] = ss[1];
- }
- }
- }
- #undef NWG
- #undef NS10
- #undef NS20
- }
- template<
- typename q4_t, // query type in shared memory
- typename k4_t, // key type in shared memory
- typename v4_t, // value type in shared memory
- typename qk_t, // Q*K type
- typename s_t, // soft-max type
- typename s4_t, // 4-wide soft-max type
- typename o4_t, // attention accumulation type
- typename kd4_t, // key type in device memory
- short nl_k,
- void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
- typename vd4_t, // value type in device memory
- short nl_v,
- void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
- short DK, // K head size
- short DV, // V head size
- short NE = 4, // head elements per thread
- short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup
- short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
- kernel void kernel_flash_attn_ext_vec(
- constant ggml_metal_kargs_flash_attn_ext_vec & args,
- device const char * q,
- device const char * k,
- device const char * v,
- device const char * mask,
- device const char * sinks,
- device const char * pad,
- device char * dst,
- threadgroup half * shmem_f16 [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- #define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
- #define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
- switch (FC_flash_attn_ext_vec_nsg) {
- // note: disabled cases to reduce library load time
- case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;
- case 2: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 2>(FWD_ARGS); break;
- case 4: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 4>(FWD_ARGS); break;
- //case 8: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 8>(FWD_ARGS); break;
- //case 16: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 16>(FWD_ARGS); break;
- //case 32: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 32>(FWD_ARGS); break;
- }
- #undef FWD_TMPL
- #undef FWD_ARGS
- }
- // note: I think s_t can be half instead of float here, because the Q*K scaling is done before storing to shared mem
- // in the other (non-vec) kernel, s_t needs to also be float because we scale during the soft_max
- //
- #define FA_TYPES \
- half4, \
- half4, \
- half4, \
- float, \
- float, float4, \
- float4
- #define FA_TYPES_F32 \
- half4, \
- float4, \
- float4, \
- float, \
- float, float4, \
- float4
- typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
- template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>;
- template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>;
- #endif
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>;
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>;
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>;
- template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
- #endif
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>;
- template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
- #endif
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
- template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>;
- template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
- #endif
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
- template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
- #endif
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
- #endif
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>;
- template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
- #endif
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
- template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
- #endif
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
- #undef FA_TYPES
- #undef FA_TYPES_F32
- constant int32_t FC_flash_attn_ext_vec_reduce_DV [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]];
- constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]];
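- // second pass used when NWG > 1: combines the per-workgroup partial results that
- // kernel_flash_attn_ext_vec interleaved into htmp (the O data first, then the (S, M) pairs),
- // applying the same max/rescale merge as the in-kernel simdgroup reduction above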
- kernel void kernel_flash_attn_ext_vec_reduce(
- constant ggml_metal_kargs_flash_attn_ext_vec_reduce & args,
- device const char * htmp,
- device char * dst,
- uint tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- #define NWG (FC_flash_attn_ext_vec_reduce_NWG)
- #define DV (FC_flash_attn_ext_vec_reduce_DV)
- const uint64_t rid = tgpig;
- const short iwg = tiisg;
- device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*NWG;
- float S = ss[rid*(2*NWG) + 2*iwg + 0];
- float M = ss[rid*(2*NWG) + 2*iwg + 1];
- const float m = simd_max(M);
- const float ms = exp(M - m);
- S = simd_sum(S*ms);
- S = S == 0.0f ? 0.0f : 1.0f/S;
- const short DV4 = DV/4;
- device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*NWG;
- device float4 * dst4 = (device float4 *) dst + rid*DV4;
- for (short i = sgitg; i < DV4; i += NWG) {
- const float4 v = simd_sum(htmp4[i*NWG + iwg]*ms);
- if (iwg == 0) {
- dst4[i] = v*S;
- }
- }
- #undef NWG
- #undef DV
- }
- template<typename T0, typename T1>
- kernel void kernel_cpy_t_t(
- constant ggml_metal_kargs_cpy & args,
- device const char * src0,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiitg[[thread_index_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- const int i03 = tgpig[2];
- const int i02 = tgpig[1];
- const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
- const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
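- // two dispatch modes: when ntg[1] == 1, the flat tgpig[0] encodes both the source row
- // (i01 = tgpig[0]%ne01) and the chunk index along ne00 (iw0 = tgpig[0]/ne01); otherwise
- // each threadgroup covers ntg[1] rows and iw0 is always 0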
- const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
- const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
- const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
- const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
- const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
- device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
- device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
- dst_data[i00] = (T1) src[0];
- break;
- }
- }
- typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;
- template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>;
- template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
- template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
- template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
- template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
- #endif
- template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
- template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
- template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
- #endif
- template<short QK,
- typename block_q,
- void (*quantize_func)(device const float *, device block_q &)>
- kernel void kernel_cpy_f32_q(
- constant ggml_metal_kargs_cpy & args,
- device const char * src0,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiitg[[thread_index_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- const int i03 = tgpig[2];
- const int i02 = tgpig[1];
- const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
- const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
- const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
- const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
- const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
- const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
- const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
- device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
- device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
- quantize_func(src, dst_data[i00]);
- break;
- }
- }
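- // each thread quantizes one block of QK source floats into one block_q; i0 is the block
- // column in the destination row and args.nk0 is presumably the number of blocks per row
- // (ne00/QK)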
- typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
- template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
- template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
- template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
- template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
- template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1, block_q5_1, quantize_q5_1>;
- template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
- template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
- kernel void kernel_cpy_q_f32(
- constant ggml_metal_kargs_cpy & args,
- device const char * src0,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiitg[[thread_index_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- const int i03 = tgpig[2];
- const int i02 = tgpig[1];
- const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
- const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
- const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
- const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
- const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
- const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
- const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
- device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
- device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
- T4x4 temp;
- dequantize_func(src_data + i00/nl, i00%nl, temp);
- dst_data[i00] = temp;
- break;
- }
- }
- typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
- template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
- template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
- template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
- template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
- template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
- template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
- template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
- template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
- template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
- template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
- kernel void kernel_concat(
- constant ggml_metal_kargs_concat & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- const int i3 = tgpig.z;
- const int i2 = tgpig.y;
- const int i1 = tgpig.x;
- int o[4] = {0, 0, 0, 0};
- o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
- device const float * x;
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
- x = (device const float *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00);
- } else {
- x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
- }
- device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- *y = *x;
- }
- }
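- // example: for a concat along dim == 2, o[2] = ne02, so elements with i2 < ne02 are read
- // from src0 and elements with i2 >= ne02 are read from src1 at the shifted index i2 - o[2]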
- template<int nr0, typename args_t>
- void kernel_mul_mv_q2_K_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- const int nb = args.ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * nr0;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
- float yl[32];
- float sumf[nr0]={0.f};
- const short ix = tiisg/8; // 0...3
- const short it = tiisg%8; // 0...7
- const short iq = it/4; // 0 or 1
- const short ir = it%4; // 0...3
- const short is = (8*ir)/16; // 0 or 1
- device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
- for (int ib = ix; ib < nb; ib += 4) {
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
- for (short i = 0; i < 8; ++i) {
- yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
- yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
- yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
- yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
- }
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
- device const half * dh = &x[ib].d;
- for (short row = 0; row < nr0; row++) {
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
- for (int i = 0; i < 8; i += 2) {
- acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
- acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
- acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
- acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
- acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
- acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
- acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
- acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
- }
- float dall = dh[0];
- float dmin = dh[1] * 1.f/16.f;
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
- (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
- (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
- (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
- dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
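- // acc1 holds products with the low-byte 2-bit fields and acc2 with the high-byte fields,
- // so acc2/256 realigns them; the low nibble of sc[] is the sub-block scale (the 1/1..1/64
- // factors undo the bit position of each field) and the high nibble, already scaled by 16,
- // pairs with dmin = dh[1]/16 to subtract the sub-block minimum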
- qs += args.nb01/2;
- sc += args.nb01;
- dh += args.nb01/2;
- }
- y4 += 4 * QK_K;
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
- float sum_all = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst_f32[first_row + row] = sum_all;
- }
- }
- }
- [[host_name("kernel_mul_mv_q2_K_f32")]]
- kernel void kernel_mul_mv_q2_K_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
- }
- template<int nr0, typename args_t>
- void kernel_mul_mv_q3_K_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- const int nb = args.ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * nr0;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
- device const float * yy = (device const float *) (src1 + offset1);
- float yl[32];
- //const uint16_t kmask1 = 0x3030;
- //const uint16_t kmask2 = 0x0f0f;
- const short tid = tiisg/4;
- const short ix = tiisg%4;
- const short ip = tid/4; // 0 or 1
- const short il = 2*((tid%4)/2); // 0 or 2
- const short ir = tid%2;
- const short l0 = 8*ir;
- // One would think that the Metal compiler would figure out that ip and il can only have
- // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
- // with these two tables.
- //
- // Possible masks for the high bit
- const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
- {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
- {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
- {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
- // Possible masks for the low 2 bits
- const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
- const ushort4 hm = mm[2*ip + il/2];
- const short shift = 2*il;
- const float v1 = il == 0 ? 4.f : 64.f;
- const float v2 = 4.f * v1;
- const uint16_t s_shift1 = 4*ip;
- const uint16_t s_shift2 = s_shift1 + il;
- const short q_offset = 32*ip + l0;
- const short y_offset = 128*ip + 32*il + l0;
- device const float * y1 = yy + ix*QK_K + y_offset;
- uint32_t scales32, aux32;
- thread uint16_t * scales16 = (thread uint16_t *)&scales32;
- thread const int8_t * scales = (thread const int8_t *)&scales32;
- float sumf1[nr0] = {0.f};
- float sumf2[nr0] = {0.f};
- for (int i = ix; i < nb; i += 4) {
- for (short l = 0; l < 8; ++l) {
- yl[l+ 0] = y1[l+ 0];
- yl[l+ 8] = y1[l+16];
- yl[l+16] = y1[l+32];
- yl[l+24] = y1[l+48];
- }
- device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
- device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
- device const uint16_t * a = (device const uint16_t *)(x[i].scales);
- device const half * dh = &x[i].d;
- for (short row = 0; row < nr0; ++row) {
- const float d_all = (float)dh[0];
- scales16[0] = a[4];
- scales16[1] = a[5];
- aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
- scales16[0] = a[il+0];
- scales16[1] = a[il+1];
- scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
- float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
- for (short l = 0; l < 8; l += 2) {
- const int32_t qs = q[l/2];
- s1 += yl[l+0] * (qs & qm[il/2][0]);
- s2 += yl[l+1] * (qs & qm[il/2][1]);
- s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
- s4 += yl[l+16] * (qs & qm[il/2][2]);
- s5 += yl[l+17] * (qs & qm[il/2][3]);
- s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
- }
- float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
- float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
- sumf1[row] += d1 * (scales[0] - 32);
- sumf2[row] += d2 * (scales[2] - 32);
- s1 = s2 = s3 = s4 = s5 = s6 = 0;
- for (short l = 0; l < 8; l += 2) {
- const int32_t qs = q[l/2+8];
- s1 += yl[l+8] * (qs & qm[il/2][0]);
- s2 += yl[l+9] * (qs & qm[il/2][1]);
- s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
- s4 += yl[l+24] * (qs & qm[il/2][2]);
- s5 += yl[l+25] * (qs & qm[il/2][3]);
- s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
- }
- d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
- d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
- sumf1[row] += d1 * (scales[1] - 32);
- sumf2[row] += d2 * (scales[3] - 32);
- q += args.nb01/2;
- h += args.nb01/2;
- a += args.nb01/2;
- dh += args.nb01/2;
- }
- y1 += 4 * QK_K;
- }
- for (int row = 0; row < nr0; ++row) {
- const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
- sumf1[row] = simd_sum(sumf);
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- if (tiisg == 0) {
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
- dst_f32[first_row + row] = sumf1[row];
- }
- }
- }
- [[host_name("kernel_mul_mv_q3_K_f32")]]
- kernel void kernel_mul_mv_q3_K_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
- }
- template<int nr0, typename args_t>
- void kernel_mul_mv_q4_K_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- constexpr uint16_t kmask1 = 0x3f3f;
- constexpr uint16_t kmask2 = 0x0f0f;
- constexpr uint16_t kmask3 = 0xc0c0;
- const short ix = tiisg/8; // 0...3
- const short it = tiisg%8; // 0...7
- const short iq = it/4; // 0 or 1
- const short ir = it%4; // 0...3
- const int nb = args.ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * nr0;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
- float yl[16];
- float yh[16];
- float sumf[nr0]={0.f};
- device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
- uint16_t sc16[4];
- thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
- for (int ib = ix; ib < nb; ib += 4) {
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
- for (short i = 0; i < 8; ++i) {
- yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
- yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
- yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
- yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
- }
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
- device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
- device const half * dh = &x[ib].d;
- for (short row = 0; row < nr0; row++) {
- sc16[0] = sc[0] & kmask1;
- sc16[1] = sc[2] & kmask1;
- sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
- sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
- device const uint16_t * q2 = q1 + 32;
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
- FOR_UNROLL (short i = 0; i < 4; ++i) {
- acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
- acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
- acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
- acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000);
- acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F);
- acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00);
- acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0);
- acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
- }
- sumf[row] += dh[0] * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
- (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
- (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
- (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
- dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
- q1 += args.nb01/2;
- sc += args.nb01/2;
- dh += args.nb01/2;
- }
- y4 += 4 * QK_K;
- }
- device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
- float sum_all = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst_f32[first_row + row] = sum_all;
- }
- }
- }
- [[host_name("kernel_mul_mv_q4_K_f32")]]
- kernel void kernel_mul_mv_q4_K_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
- }
- template<int nr0, typename args_t>
- void kernel_mul_mv_q5_K_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- const int nb = args.ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * nr0;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
- device const float * yy = (device const float *) (src1 + offset1);
- float sumf[nr0]={0.f};
- float yl[16], yh[16];
- constexpr uint16_t kmask1 = 0x3f3f;
- constexpr uint16_t kmask2 = 0x0f0f;
- constexpr uint16_t kmask3 = 0xc0c0;
- const short tid = tiisg/4;
- const short ix = tiisg%4;
- const short iq = tid/4;
- const short ir = tid%4;
- const short l0 = 8*ir;
- const short q_offset = 32*iq + l0;
- const short y_offset = 64*iq + l0;
- const uint8_t hm1 = 1u << (2*iq);
- const uint8_t hm2 = hm1 << 1;
- const uint8_t hm3 = hm1 << 4;
- const uint8_t hm4 = hm2 << 4;
- uint16_t sc16[4];
- thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
- device const float * y1 = yy + ix*QK_K + y_offset;
- for (int i = ix; i < nb; i += 4) {
- device const uint8_t * q1 = x[i].qs + q_offset;
- device const uint8_t * qh = x[i].qh + l0;
- device const half * dh = &x[i].d;
- device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
- device const float * y2 = y1 + 128;
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
- for (short l = 0; l < 8; ++l) {
- yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
- yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
- yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
- yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
- }
- for (short row = 0; row < nr0; ++row) {
- device const uint8_t * q2 = q1 + 64;
- sc16[0] = a[0] & kmask1;
- sc16[1] = a[2] & kmask1;
- sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
- sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
- float4 acc1 = {0.f};
- float4 acc2 = {0.f};
- FOR_UNROLL (short l = 0; l < 8; ++l) {
- uint8_t h = qh[l];
- acc1[0] += yl[l+0] * (q1[l] & 0x0F);
- acc1[1] += yl[l+8] * (q1[l] & 0xF0);
- acc1[2] += yh[l+0] * (q2[l] & 0x0F);
- acc1[3] += yh[l+8] * (q2[l] & 0xF0);
- acc2[0] += h & hm1 ? yl[l+0] : 0.f;
- acc2[1] += h & hm2 ? yl[l+8] : 0.f;
- acc2[2] += h & hm3 ? yh[l+0] : 0.f;
- acc2[3] += h & hm4 ? yh[l+8] : 0.f;
- }
- sumf[row] += dh[0] * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
- sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
- sc8[4] * (acc1[2] + 16.f*acc2[2]) +
- sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
- dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
- q1 += args.nb01;
- qh += args.nb01;
- dh += args.nb01/2;
- a += args.nb01/2;
- }
- y1 += 4 * QK_K;
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
- const float tot = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst_f32[first_row + row] = tot;
- }
- }
- }
- [[host_name("kernel_mul_mv_q5_K_f32")]]
- kernel void kernel_mul_mv_q5_K_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
- }
- template<int nr0, typename args_t>
- void kernel_mul_mv_q6_K_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- constexpr uint8_t kmask1 = 0x03;
- constexpr uint8_t kmask2 = 0x0C;
- constexpr uint8_t kmask3 = 0x30;
- constexpr uint8_t kmask4 = 0xC0;
- const int nb = args.ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * nr0;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
- device const float * yy = (device const float *) (src1 + offset1);
- float sumf[nr0] = { 0.f };
- float yl[16];
- const short tid = tiisg/2;
- const short ix = tiisg%2;
- const short ip = tid/8; // 0 or 1
- const short il = tid%8;
- const short l0 = 4*il;
- const short is = 8*ip + l0/16;
- const short y_offset = 128*ip + l0;
- const short q_offset_l = 64*ip + l0;
- const short q_offset_h = 32*ip + l0;
- for (int i = ix; i < nb; i += 2) {
- device const uint8_t * q1 = x[i].ql + q_offset_l;
- device const uint8_t * q2 = q1 + 32;
- device const uint8_t * qh = x[i].qh + q_offset_h;
- device const int8_t * sc = x[i].scales + is;
- device const half * dh = &x[i].d;
- device const float * y = yy + i * QK_K + y_offset;
- for (short l = 0; l < 4; ++l) {
- yl[4*l + 0] = y[l + 0];
- yl[4*l + 1] = y[l + 32];
- yl[4*l + 2] = y[l + 64];
- yl[4*l + 3] = y[l + 96];
- }
- for (short row = 0; row < nr0; ++row) {
- float4 sums = {0.f, 0.f, 0.f, 0.f};
- FOR_UNROLL (short l = 0; l < 4; ++l) {
- sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
- sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
- sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
- sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
- }
- sumf[row] += dh[0] * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
- q1 += args.nb01;
- q2 += args.nb01;
- qh += args.nb01;
- sc += args.nb01;
- dh += args.nb01/2;
- }
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
- float sum_all = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst_f32[first_row + row] = sum_all;
- }
- }
- }
- [[host_name("kernel_mul_mv_q6_K_f32")]]
- kernel void kernel_mul_mv_q6_K_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
- }
- // ======================= "True" 2-bit
- template<int nr0, typename args_t>
- void kernel_mul_mv_iq2_xxs_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- const int nb = args.ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * nr0;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
- float yl[32];
- float sumf[nr0]={0.f};
- const int nb32 = nb * (QK_K / 32);
- threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
- threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256);
- {
- int nval = 4;
- int pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i];
- nval = 2;
- pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
- const int ix = tiisg;
- device const float * y4 = y + 32 * ix;
- for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
- for (short i = 0; i < 32; ++i) {
- yl[i] = y4[i];
- }
- const int ibl = ib32 / (QK_K / 32);
- const int ib = ib32 % (QK_K / 32);
- device const block_iq2_xxs * xr = x + ibl;
- device const uint16_t * q2 = xr->qs + 4 * ib;
- device const half * dh = &xr->d;
- for (short row = 0; row < nr0; row++) {
- const float db = dh[0];
- device const uint8_t * aux8 = (device const uint8_t *)q2;
- const uint32_t aux32 = q2[2] | (q2[3] << 16);
- const float d = db * (0.5f + (aux32 >> 28));
- float sum = 0;
- for (short l = 0; l < 4; ++l) {
- const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
- const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
- for (short j = 0; j < 8; ++j) {
- sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
- }
- }
- sumf[row] += d * sum;
- dh += args.nb01/2;
- q2 += args.nb01/2;
- }
- y4 += 32 * 32;
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
- float sum_all = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst_f32[first_row + row] = sum_all * 0.25f;
- }
- }
- }
- [[host_name("kernel_mul_mv_iq2_xxs_f32")]]
- kernel void kernel_mul_mv_iq2_xxs_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- template<int nr0, typename args_t>
- void kernel_mul_mv_iq2_xs_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- const int nb = args.ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * nr0;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
- float yl[32];
- float sumf[nr0]={0.f};
- const int nb32 = nb * (QK_K / 32);
- threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
- threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 512);
- {
- int nval = 8;
- int pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i];
- nval = 2;
- pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
- const int ix = tiisg;
- device const float * y4 = y + 32 * ix;
- for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
- for (short i = 0; i < 32; ++i) {
- yl[i] = y4[i];
- }
- const int ibl = ib32 / (QK_K / 32);
- const int ib = ib32 % (QK_K / 32);
- device const block_iq2_xs * xr = x + ibl;
- device const uint16_t * q2 = xr->qs + 4 * ib;
- device const uint8_t * sc = xr->scales + ib;
- device const half * dh = &xr->d;
- for (short row = 0; row < nr0; row++) {
- const float db = dh[0];
- const uint8_t ls1 = sc[0] & 0xf;
- const uint8_t ls2 = sc[0] >> 4;
- const float d1 = db * (0.5f + ls1);
- const float d2 = db * (0.5f + ls2);
- float sum1 = 0, sum2 = 0;
- for (short l = 0; l < 2; ++l) {
- const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
- const uint8_t signs = ssigns[(q2[l] >> 9)];
- for (short j = 0; j < 8; ++j) {
- sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
- }
- }
- for (short l = 2; l < 4; ++l) {
- const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
- const uint8_t signs = ssigns[(q2[l] >> 9)];
- for (short j = 0; j < 8; ++j) {
- sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
- }
- }
- sumf[row] += d1 * sum1 + d2 * sum2;
- dh += args.nb01/2;
- q2 += args.nb01/2;
- sc += args.nb01;
- }
- y4 += 32 * 32;
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
- float sum_all = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst_f32[first_row + row] = sum_all * 0.25f;
- }
- }
- }
- [[host_name("kernel_mul_mv_iq2_xs_f32")]]
- kernel void kernel_mul_mv_iq2_xs_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- template<int nr0, typename args_t>
- void kernel_mul_mv_iq3_xxs_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- const int nb = args.ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * nr0;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
- float yl[32];
- float sumf[nr0]={0.f};
- const int nb32 = nb * (QK_K / 32);
- threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem);
- threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256);
- {
- int nval = 4;
- int pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i];
- nval = 2;
- pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
- const int ix = tiisg;
- device const float * y4 = y + 32 * ix;
- for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
- for (short i = 0; i < 32; ++i) {
- yl[i] = y4[i];
- }
- const int ibl = ib32 / (QK_K / 32);
- const int ib = ib32 % (QK_K / 32);
- device const block_iq3_xxs * xr = x + ibl;
- device const uint8_t * q3 = xr->qs + 8 * ib;
- device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
- device const half * dh = &xr->d;
- for (short row = 0; row < nr0; row++) {
- const float db = dh[0];
- const uint32_t aux32 = gas[0] | (gas[1] << 16);
- const float d = db * (0.5f + (aux32 >> 28));
- float2 sum = {0};
- for (short l = 0; l < 4; ++l) {
- const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
- const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
- const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
- for (short j = 0; j < 4; ++j) {
- sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
- sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
- }
- }
- sumf[row] += d * (sum[0] + sum[1]);
- dh += args.nb01/2;
- q3 += args.nb01;
- gas += args.nb01/2;
- }
- y4 += 32 * 32;
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
- float sum_all = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst_f32[first_row + row] = sum_all * 0.5f;
- }
- }
- }
- [[host_name("kernel_mul_mv_iq3_xxs_f32")]]
- kernel void kernel_mul_mv_iq3_xxs_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- template<int nr0, typename args_t>
- void kernel_mul_mv_iq3_s_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- const int nb = args.ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * nr0;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
- float yl[32];
- float sumf[nr0]={0.f};
- const int nb32 = nb * (QK_K / 32);
- threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem;
- {
- int nval = 8;
- int pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i];
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
- const int ix = tiisg;
- device const float * y4 = y + 32 * ix;
- for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
- for (short i = 0; i < 32; ++i) {
- yl[i] = y4[i];
- }
- const int ibl = ib32 / (QK_K / 32);
- const int ib = ib32 % (QK_K / 32);
- device const block_iq3_s * xr = x + ibl;
- device const uint8_t * qs = xr->qs + 8 * ib;
- device const uint8_t * qh = xr->qh + ib;
- device const uint8_t * sc = xr->scales + (ib/2);
- device const uint8_t * signs = xr->signs + 4 * ib;
- device const half * dh = &xr->d;
- for (short row = 0; row < nr0; row++) {
- const float db = dh[0];
- const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
- float2 sum = {0};
- for (short l = 0; l < 4; ++l) {
- const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
- const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
- const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
- const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
- for (short j = 0; j < 4; ++j) {
- sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
- sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
- }
- }
- sumf[row] += d * (sum[0] + sum[1]);
- dh += args.nb01/2;
- qs += args.nb01;
- qh += args.nb01;
- sc += args.nb01;
- signs += args.nb01;
- }
- y4 += 32 * 32;
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
- float sum_all = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst_f32[first_row + row] = sum_all;
- }
- }
- }
- [[host_name("kernel_mul_mv_iq3_s_f32")]]
- kernel void kernel_mul_mv_iq3_s_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- template<int nr0, typename args_t>
- void kernel_mul_mv_iq2_s_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- const int nb = args.ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * nr0;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
- float yl[32];
- float sumf[nr0]={0.f};
- const int nb32 = nb * (QK_K / 32);
- //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem;
- //{
- // int nval = 32;
- // int pos = (32*sgitg + tiisg)*nval;
- // for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i];
- // threadgroup_barrier(mem_flags::mem_threadgroup);
- //}
- const short ix = tiisg;
- device const float * y4 = y + 32 * ix;
- for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
- for (short i = 0; i < 32; ++i) {
- yl[i] = y4[i];
- }
- const int ibl = ib32 / (QK_K / 32);
- const int ib = ib32 % (QK_K / 32);
- device const block_iq2_s * xr = x + ibl;
- device const uint8_t * qs = xr->qs + 4 * ib;
- device const uint8_t * qh = xr->qh + ib;
- device const uint8_t * sc = xr->scales + ib;
- device const uint8_t * signs = qs + QK_K/8;
- device const half * dh = &xr->d;
- for (short row = 0; row < nr0; row++) {
- const float db = dh[0];
- const float d1 = db * (0.5f + (sc[0] & 0xf));
- const float d2 = db * (0.5f + (sc[0] >> 4));
- float2 sum = {0};
- for (short l = 0; l < 2; ++l) {
- //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
- //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
- constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
- constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
- for (short j = 0; j < 8; ++j) {
- sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
- sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
- }
- }
- sumf[row] += d1 * sum[0] + d2 * sum[1];
- dh += args.nb01/2;
- qs += args.nb01;
- qh += args.nb01;
- sc += args.nb01;
- signs += args.nb01;
- }
- y4 += 32 * 32;
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
- float sum_all = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst_f32[first_row + row] = sum_all * 0.25f;
- }
- }
- }
- [[host_name("kernel_mul_mv_iq2_s_f32")]]
- kernel void kernel_mul_mv_iq2_s_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- template<int nr0, typename args_t>
- void kernel_mul_mv_iq1_s_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- const int nb = args.ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * nr0;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
- float yl[32];
- float sumf[nr0]={0.f};
- const int nb32 = nb * (QK_K / 32);
- const short ix = tiisg;
- device const float * y4 = y + 32 * ix;
- for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
- float sumy = 0;
- for (short i = 0; i < 32; ++i) {
- yl[i] = y4[i];
- sumy += yl[i];
- }
- const int ibl = ib32 / (QK_K / 32);
- const int ib = ib32 % (QK_K / 32);
- device const block_iq1_s * xr = x + ibl;
- device const uint8_t * qs = xr->qs + 4 * ib;
- device const uint16_t * qh = xr->qh + ib;
- device const half * dh = &xr->d;
- for (short row = 0; row < nr0; row++) {
- constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
- constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
- constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
- constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
- float sum = 0;
- for (short j = 0; j < 4; ++j) {
- sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
- + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
- + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
- + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
- }
- sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
- dh += args.nb01/2;
- qs += args.nb01;
- qh += args.nb01/2;
- }
- y4 += 32 * 32;
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
- float sum_all = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst_f32[first_row + row] = sum_all;
- }
- }
- }
- [[host_name("kernel_mul_mv_iq1_s_f32")]]
- kernel void kernel_mul_mv_iq1_s_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
- }
- template<int nr0, typename args_t>
- void kernel_mul_mv_iq1_m_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- const int nb = args.ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * nr0;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
- float yl[32];
- float sumf[nr0]={0.f};
- const int nb32 = nb * (QK_K / 32);
- const short ix = tiisg;
- device const float * y4 = y + 32 * ix;
- iq1m_scale_t scale;
- for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
- float4 sumy = {0.f};
- for (short i = 0; i < 8; ++i) {
- yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
- yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
- yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
- yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
- }
- const int ibl = ib32 / (QK_K / 32);
- const int ib = ib32 % (QK_K / 32);
- device const block_iq1_m * xr = x + ibl;
- device const uint8_t * qs = xr->qs + 4 * ib;
- device const uint8_t * qh = xr->qh + 2 * ib;
- device const uint16_t * sc = (device const uint16_t *)xr->scales;
- for (short row = 0; row < nr0; row++) {
- scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
- constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
- constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
- constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
- constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
- float2 sum = {0.f};
- for (short j = 0; j < 4; ++j) {
- sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
- + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
- sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
- + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
- }
- const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
- const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
- sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
- (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
- sc += args.nb01/2;
- qs += args.nb01;
- qh += args.nb01;
- }
- y4 += 32 * 32;
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
- float sum_all = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst_f32[first_row + row] = sum_all;
- }
- }
- }
- [[host_name("kernel_mul_mv_iq1_m_f32")]]
- kernel void kernel_mul_mv_iq1_m_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
- }
- template<int NR0, typename args_t>
- void kernel_mul_mv_iq4_nl_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- threadgroup float * shmem_f32 = (threadgroup float *) shmem;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * NR0;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
- const int nb = args.ne00/QK4_NL;
- const int ns01 = args.nb01/args.nb00;
- const short ix = tiisg/2; // 0...15
- const short it = tiisg%2; // 0 or 1
- shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
- threadgroup_barrier(mem_flags::mem_threadgroup);
- float4 yl[4];
- float sumf[NR0]={0.f};
- device const float * yb = y + ix*QK4_NL + it*8;
- uint32_t aux32[2];
- thread const uint8_t * q8 = (thread const uint8_t *)aux32;
- float4 qf1, qf2;
- // [TAG_MUL_MV_WEIRD]
- for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
- device const float4 * y4 = (device const float4 *)yb;
- yl[0] = y4[0];
- yl[1] = y4[4];
- yl[2] = y4[1];
- yl[3] = y4[5];
- for (short row = 0; row < NR0; row++) {
- device const block_iq4_nl & xb = x[row*ns01 + ib];
- device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
- float4 acc1 = {0.f}, acc2 = {0.f};
- aux32[0] = q4[0] | (q4[1] << 16);
- aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
- aux32[0] &= 0x0f0f0f0f;
- qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
- qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
- acc1 += yl[0] * qf1;
- acc2 += yl[1] * qf2;
- aux32[0] = q4[2] | (q4[3] << 16);
- aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
- aux32[0] &= 0x0f0f0f0f;
- qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
- qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
- acc1 += yl[2] * qf1;
- acc2 += yl[3] * qf2;
- acc1 += acc2;
- sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
- }
- yb += 16 * QK4_NL;
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
- float sum_all = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst_f32[first_row + row] = sum_all;
- }
- }
- }
- [[host_name("kernel_mul_mv_iq4_nl_f32")]]
- kernel void kernel_mul_mv_iq4_nl_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- template<int NR0, typename args_t>
- void kernel_mul_mv_iq4_xs_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- threadgroup float * shmem_f32 = (threadgroup float *) shmem;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * NR0;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
- const int nb = args.ne00/QK_K;
- const int ns01 = args.nb01/args.nb00;
- const short ix = tiisg/16; // 0 or 1
- const short it = tiisg%16; // 0...15
- const short ib = it/2;
- const short il = it%2;
- shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
- threadgroup_barrier(mem_flags::mem_threadgroup);
- float4 yl[4];
- float sumf[NR0]={0.f};
- device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
- uint32_t aux32[2];
- thread const uint8_t * q8 = (thread const uint8_t *)aux32;
- float4 qf1, qf2;
- // [TAG_MUL_MV_WEIRD]
- for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) {
- device const float4 * y4 = (device const float4 *)yb;
- yl[0] = y4[0];
- yl[1] = y4[4];
- yl[2] = y4[1];
- yl[3] = y4[5];
- for (short row = 0; row < NR0; ++row) {
- device const block_iq4_xs & xb = x[row*ns01 + ibl];
- device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
- float4 acc1 = {0.f}, acc2 = {0.f};
- aux32[0] = (q4[0] ) & 0x0f0f0f0f;
- aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
- qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
- qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
- acc1 += yl[0] * qf1;
- acc2 += yl[1] * qf2;
- aux32[0] = (q4[1] ) & 0x0f0f0f0f;
- aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
- qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
- qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
- acc1 += yl[2] * qf1;
- acc2 += yl[3] * qf2;
- acc1 += acc2;
- const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
- sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
- }
- yb += 2 * QK_K;
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
- float sum_all = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst_f32[first_row + row] = sum_all;
- }
- }
- }
- [[host_name("kernel_mul_mv_iq4_xs_f32")]]
- kernel void kernel_mul_mv_iq4_xs_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- template<int NR0, typename args_t>
- void kernel_mul_mv_mxfp4_f32_impl(
- args_t args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const short NSG = FC_mul_mv_nsg;
- threadgroup float * shmem_f32 = (threadgroup float *) shmem;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = (r0 * NSG + sgitg) * NR0;
- const uint i12 = im%args.ne12;
- const uint i13 = im/args.ne12;
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
- device const float * y = (device const float *) (src1 + offset1);
- const int nb = args.ne00/QK_MXFP4;
- const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors
- const short ix = tiisg/2; // 0...15
- const short it = tiisg%2; // 0 or 1
- shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
- threadgroup_barrier(mem_flags::mem_threadgroup);
- float4 yl[4];
- float sumf[NR0]={0.f};
- device const float * yb = y + ix*QK_MXFP4 + it*8;
- // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster
- // no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD]
- for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
- device const float4 * y4 = (device const float4 *) yb;
- yl[0] = y4[0];
- yl[1] = y4[4];
- yl[2] = y4[1];
- yl[3] = y4[5];
- FOR_UNROLL (short row = 0; row < NR0; row++) {
- device const block_mxfp4 & xb = x[row*ns01 + ib];
- device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
- float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
- float4 acc2 = yl[1]*float4(shmem_f32[q2[0] >> 4 ], shmem_f32[q2[1] >> 4 ], shmem_f32[q2[2] >> 4 ], shmem_f32[q2[3] >> 4 ]);
- float4 acc3 = yl[2]*float4(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]);
- float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4 ], shmem_f32[q2[5] >> 4 ], shmem_f32[q2[6] >> 4 ], shmem_f32[q2[7] >> 4 ]);
- acc1 = (acc1 + acc3) + (acc2 + acc4);
- sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + acc1[3]));
- }
- yb += 16 * QK_MXFP4;
- }
- device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
- float sum_all = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst_f32[first_row + row] = sum_all;
- }
- }
- }
- [[host_name("kernel_mul_mv_mxfp4_f32")]]
- kernel void kernel_mul_mv_mxfp4_f32(
- constant ggml_metal_kargs_mul_mv & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
- kernel void kernel_get_rows_q(
- constant ggml_metal_kargs_get_rows & args,
- device const void * src0,
- device const void * src1,
- device void * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiitg[[thread_index_in_threadgroup]],
- ushort3 ntg [[threads_per_threadgroup]]) {
- const int32_t iw0 = tgpig.x/args.ne10;
- const int32_t i10 = tgpig.x%args.ne10;
- const int32_t i11 = tgpig.y;
- const int32_t i12 = tgpig.z;
- const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
- const int32_t i02 = i11;
- const int32_t i03 = i12;
- auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
- auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
- for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
- float4x4 temp;
- dequantize_func(psrc + ind/nl, ind%nl, temp);
- pdst[ind] = temp;
- break;
- }
- }
- template<typename T0, typename T>
- kernel void kernel_get_rows_f(
- constant ggml_metal_kargs_get_rows & args,
- device const void * src0,
- device const void * src1,
- device void * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiitg[[thread_index_in_threadgroup]],
- ushort3 ntg [[threads_per_threadgroup]]) {
- const int32_t iw0 = tgpig.x/args.ne10;
- const int32_t i10 = tgpig.x%args.ne10;
- const int32_t i11 = tgpig.y;
- const int32_t i12 = tgpig.z;
- const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
- const int32_t i02 = i11;
- const int32_t i03 = i12;
- auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
- auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
- for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
- pdst[ind] = psrc[ind];
- break;
- }
- }
- template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
- kernel void kernel_set_rows_q32(
- constant ggml_metal_kargs_set_rows & args,
- device const void * src0,
- device const void * src1,
- device float * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint3 tptg [[threads_per_threadgroup]]) {
- const int32_t i03 = tgpig.z;
- const int32_t i02 = tgpig.y;
- const int32_t i12 = i03%args.ne12;
- const int32_t i11 = i02%args.ne11;
- const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
- if (i01 >= args.ne01) {
- return;
- }
- const int32_t i10 = i01;
- const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
- device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
- const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
- for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
- quantize_func(src_row + 32*ind, dst_row[ind]);
- }
- }
- template<typename T, typename TI>
- kernel void kernel_set_rows_f(
- constant ggml_metal_kargs_set_rows & args,
- device const void * src0,
- device const void * src1,
- device float * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint3 tptg [[threads_per_threadgroup]]) {
- const int32_t i03 = tgpig.z;
- const int32_t i02 = tgpig.y;
- const int32_t i12 = i03%args.ne12;
- const int32_t i11 = i02%args.ne11;
- const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
- if (i01 >= args.ne01) {
- return;
- }
- const int32_t i10 = i01;
- const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
- device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
- const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
- for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
- dst_row[ind] = (T) src_row[ind];
- }
- }
- constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
- constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
- // each block_q contains 16*nl weights
- template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
- kernel void kernel_mul_mm(
- constant ggml_metal_kargs_mul_mm & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiitg[[thread_index_in_threadgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- threadgroup S0 * sa = (threadgroup S0 *)(shmem);
- threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
- threadgroup float * sc = (threadgroup float *)(shmem);
- constexpr int NR0 = 64;
- constexpr int NR1 = 32;
- constexpr int NK = 32;
- constexpr int NL0 = NK/16;
- constexpr int NL1 = NK/8;
- const int im = tgpig.z;
- const int r0 = tgpig.y*NR0;
- const int r1 = tgpig.x*NR1;
- // if this block is of 64x32 shape or smaller
- const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
- const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;
- // a thread shouldn't load data outside of the matrix
- const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
- const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
- const short il0 = (tiitg % NL0);
- short il = il0;
- const int i12 = im%args.ne12;
- const int i13 = im/args.ne12;
- const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const short offset1 = il0/nl;
- device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
- const short iy = 8*(tiitg % NL1);
- device const T1 * y = (device const T1 *)(src1
- + args.nb13*i13
- + args.nb12*i12
- + args.nb11*(r1 + lr1)
- + args.nb10*iy);
- #ifndef GGML_METAL_HAS_TENSOR
- S0_8x8 ma[4];
- S1_8x8 mb[2];
- simdgroup_float8x8 mc[8];
- for (short i = 0; i < 8; i++){
- mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
- }
- #else
- auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
- auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
- mpp::tensor_ops::matmul2d<
- mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
- execution_simdgroups<4>> mm;
- auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
- #endif
- for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
- #ifndef GGML_METAL_HAS_TENSOR
- // load data and store to threadgroup memory
- if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
- threadgroup_barrier(mem_flags::mem_threadgroup);
- // no need for dequantization
- for (short i = 0; i < 16; i++) {
- const short sx = 2*il0 + i/8;
- const short sy = (tiitg/NL0)/8;
- //const short lx = i%8;
- //const short ly = (tiitg/NL0)%8;
- const short lx = (tiitg/NL0)%8;
- const short ly = i%8;
- const short ib = 8*sx + sy;
- *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
- }
- } else {
- S0_4x4 temp_a;
- dequantize_func(x, il, temp_a);
- threadgroup_barrier(mem_flags::mem_threadgroup);
- FOR_UNROLL (short i = 0; i < 16; i++) {
- const short sx = 2*il0 + i/8;
- const short sy = (tiitg/NL0)/8;
- //const short lx = i%8;
- //const short ly = (tiitg/NL0)%8;
- const short lx = (tiitg/NL0)%8;
- const short ly = i%8;
- const short ib = 8*sx + sy;
- // NOTE: this is massively slower.. WTF?
- //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
- *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
- }
- }
- if (FC_mul_mm_bc_inp) {
- for (short i = 0; i < 8; ++i) {
- const short sx = (tiitg%NL1);
- const short sy = (tiitg/NL1)/8;
- const short lx = i;
- const short ly = (tiitg/NL1)%8;
- //const short lx = (tiitg/NL1)%8;
- //const short ly = i;
- const short ib = 4*sx + sy;
- *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
- }
- } else {
- const short sx = (tiitg%NL1);
- const short sy = (tiitg/NL1)/8;
- const short dx = sx;
- const short dy = sy;
- const short ly = (tiitg/NL1)%8;
- const short ib = 4*sx + sy;
- *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
- }
- #else
- // load data and store to threadgroup memory
- if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
- threadgroup_barrier(mem_flags::mem_threadgroup);
- // no need for dequantization
- for (short i = 0; i < 16; i++) {
- const short sx = 2*il0 + i/8;
- const short sy = (tiitg/NL0)/8;
- const short lx = i%8;
- const short ly = (tiitg/NL0)%8;
- //const short lx = (tiitg/NL0)%8;
- //const short ly = i%8;
- *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
- }
- } else {
- S0_4x4 temp_a;
- dequantize_func(x, il, temp_a);
- threadgroup_barrier(mem_flags::mem_threadgroup);
- FOR_UNROLL (short i = 0; i < 16; i++) {
- const short sx = 2*il0 + i/8;
- const short sy = (tiitg/NL0)/8;
- const short lx = i%8;
- const short ly = (tiitg/NL0)%8;
- //const short lx = (tiitg/NL0)%8;
- //const short ly = i%8;
- *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
- }
- }
- if (FC_mul_mm_bc_inp) {
- for (short i = 0; i < 8; ++i) {
- const short sx = (tiitg%NL1);
- const short sy = (tiitg/NL1)/8;
- const short lx = i;
- const short ly = (tiitg/NL1)%8;
- //const short lx = (tiitg/NL1)%8;
- //const short ly = i;
- *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
- }
- } else {
- const short sx = (tiitg%NL1);
- const short sy = (tiitg/NL1)/8;
- //const short lx = i;
- const short ly = (tiitg/NL1)%8;
- //const short lx = (tiitg/NL1)%8;
- //const short ly = i;
- *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
- }
- #endif
- il = (il + 2 < nl) ? il + 2 : il % 2;
- x = (il < 2) ? x + (2 + nl - 1)/nl : x;
- y += NK;
- threadgroup_barrier(mem_flags::mem_threadgroup);
- #ifndef GGML_METAL_HAS_TENSOR
- // load matrices from threadgroup memory and conduct outer products
- threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
- threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
- FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
- simdgroup_barrier(mem_flags::mem_none);
- FOR_UNROLL (short i = 0; i < 4; i++) {
- simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
- }
- simdgroup_barrier(mem_flags::mem_none);
- FOR_UNROLL (short i = 0; i < 2; i++) {
- simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
- }
- simdgroup_barrier(mem_flags::mem_none);
- FOR_UNROLL (short i = 0; i < 8; i++){
- simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
- }
- lsma += 8*64;
- lsmb += 4*64;
- }
- #else
- auto sA = tA.slice(0, 0);
- auto sB = tB.slice(0, 0);
- mm.run(sB, sA, cT);
- #endif
- }
- if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
- // if no bounds checks on the output are needed, we can directly write to device memory
- #ifdef GGML_METAL_HAS_TENSOR
- device float * C = (device float *) dst +
- r0 + \
- r1 * args.ne0 + im*args.ne1*args.ne0;
- auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
- cT.store(tC);
- #else
- device float * C = (device float *) dst +
- (r0 + 32*(sgitg & 1)) + \
- (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
- for (short i = 0; i < 8; i++) {
- simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
- }
- #endif
- } else {
- // block is smaller than 64x32, we should avoid writing data outside of the matrix
- threadgroup_barrier(mem_flags::mem_threadgroup);
- threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
- #ifdef GGML_METAL_HAS_TENSOR
- auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
- cT.store(tC);
- #else
- for (short i = 0; i < 8; i++) {
- simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
- }
- #endif
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (sgitg == 0) {
- for (int j = tiitg; j < nr1; j += NR1) {
- device float * D = (device float *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
- device float4 * D4 = (device float4 *) D;
- threadgroup float * C = temp_str + (j*NR0);
- threadgroup float4 * C4 = (threadgroup float4 *) C;
- int i = 0;
- for (; i < nr0/4; i++) {
- *(D4 + i) = *(C4 + i);
- }
- i *= 4;
- for (; i < nr0; i++) {
- *(D + i) = *(C + i);
- }
- }
- }
- }
- }
- template<short ne20> // n_expert_used
- kernel void kernel_mul_mm_id_map0(
- constant ggml_metal_kargs_mul_mm_id_map0 & args,
- device const char * src2,
- device char * htpe,
- device char * hids,
- threadgroup char * shmem [[threadgroup(0)]],
- ushort tpitg[[thread_position_in_threadgroup]],
- ushort ntg[[threads_per_threadgroup]]) {
- const short ide = tpitg; // expert id
- uint32_t n_all = 0;
- device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
- for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
- if (i21 + tpitg < args.ne21) {
- device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
- threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
- #pragma unroll(ne20)
- for (short i20 = 0; i20 < ne20; i20++) {
- sids[i20] = src2_i32[i20];
- }
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- for (short t = 0; t < ntg; t++) {
- if (i21 + t >= args.ne21) {
- break;
- }
- threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
- short sel = 0;
- #pragma unroll(ne20)
- for (short i20 = 0; i20 < ne20; i20++) {
- sel += (sids[i20] == ide)*(i20 + 1);
- }
- ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
- n_all += sel > 0;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
- device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
- tpe_u32[ide] = n_all;
- }
- typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
- template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
- template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
- template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
- template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
- template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
- template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
- template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
- template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
- kernel void kernel_mul_mm_id(
- constant ggml_metal_kargs_mul_mm_id & args,
- device const char * src0,
- device const char * src1,
- device const char * htpe,
- device const char * hids,
- device char * dst,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiitg[[thread_index_in_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- threadgroup S0 * sa = (threadgroup S0 *)(shmem);
- threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
- threadgroup float * sc = (threadgroup float *)(shmem);
- constexpr int NR0 = 64;
- constexpr int NR1 = 32;
- constexpr int NK = 32;
- constexpr int NL0 = NK/16;
- constexpr int NL1 = NK/8;
- const int im = tgpig.z; // expert
- const int r0 = tgpig.y*NR0;
- const int r1 = tgpig.x*NR1;
- device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
- device const int32_t * ids_i32 = (device const int32_t *) (hids);
- const int32_t neh1 = tpe_u32[im];
- if (r1 >= neh1) {
- return;
- }
- // if this block is of 64x32 shape or smaller
- const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
- const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1;
- // a thread shouldn't load data outside of the matrix
- const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
- const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
- const short il0 = (tiitg % NL0);
- short il = il0;
- const int id = ids_i32[im*args.ne21 + r1 + lr1];
- const short i11 = (id % args.ne20) % args.ne11;
- const short i12 = (id / args.ne20);
- const short i13 = 0;
- const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
- const short offset1 = il0/nl;
- device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
- const short iy = 8*(tiitg % NL1);
- device const T1 * y = (device const T1 *)(src1
- + args.nb13*i13
- + args.nb12*i12
- + args.nb11*i11
- + args.nb10*iy);
- #ifndef GGML_METAL_HAS_TENSOR
- S0_8x8 ma[4];
- S1_8x8 mb[2];
- simdgroup_float8x8 mc[8];
- for (short i = 0; i < 8; i++){
- mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
- }
- #else
- auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
- auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
- mpp::tensor_ops::matmul2d<
- mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
- execution_simdgroups<4>> mm;
- auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
- #endif
- for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
- #ifndef GGML_METAL_HAS_TENSOR
- // load data and store to threadgroup memory
- if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
- threadgroup_barrier(mem_flags::mem_threadgroup);
- // no need for dequantization
- for (short i = 0; i < 16; i++) {
- const short sx = 2*il0 + i/8;
- const short sy = (tiitg/NL0)/8;
- //const short lx = i%8;
- //const short ly = (tiitg/NL0)%8;
- const short lx = (tiitg/NL0)%8;
- const short ly = i%8;
- const short ib = 8*sx + sy;
- *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
- }
- } else {
- S0_4x4 temp_a;
- dequantize_func(x, il, temp_a);
- threadgroup_barrier(mem_flags::mem_threadgroup);
- FOR_UNROLL (short i = 0; i < 16; i++) {
- const short sx = 2*il0 + i/8;
- const short sy = (tiitg/NL0)/8;
- //const short lx = i%8;
- //const short ly = (tiitg/NL0)%8;
- const short lx = (tiitg/NL0)%8;
- const short ly = i%8;
- const short ib = 8*sx + sy;
- // NOTE: this is massively slower.. WTF?
- //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
- *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
- }
- }
- if (FC_mul_mm_bc_inp) {
- for (short i = 0; i < 8; ++i) {
- const short sx = (tiitg%NL1);
- const short sy = (tiitg/NL1)/8;
- const short lx = i;
- const short ly = (tiitg/NL1)%8;
- //const short lx = (tiitg/NL1)%8;
- //const short ly = i;
- const short ib = 4*sx + sy;
- *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
- }
- } else {
- const short sx = (tiitg%NL1);
- const short sy = (tiitg/NL1)/8;
- const short dx = sx;
- const short dy = sy;
- const short ly = (tiitg/NL1)%8;
- const short ib = 4*sx + sy;
- *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
- }
- #else
- // load data and store to threadgroup memory
- if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
- threadgroup_barrier(mem_flags::mem_threadgroup);
- // no need for dequantization
- for (short i = 0; i < 16; i++) {
- const short sx = 2*il0 + i/8;
- const short sy = (tiitg/NL0)/8;
- const short lx = i%8;
- const short ly = (tiitg/NL0)%8;
- //const short lx = (tiitg/NL0)%8;
- //const short ly = i%8;
- *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
- }
- } else {
- S0_4x4 temp_a;
- dequantize_func(x, il, temp_a);
- threadgroup_barrier(mem_flags::mem_threadgroup);
- FOR_UNROLL (short i = 0; i < 16; i++) {
- const short sx = 2*il0 + i/8;
- const short sy = (tiitg/NL0)/8;
- const short lx = i%8;
- const short ly = (tiitg/NL0)%8;
- //const short lx = (tiitg/NL0)%8;
- //const short ly = i%8;
- *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
- }
- }
- if (FC_mul_mm_bc_inp) {
- for (short i = 0; i < 8; ++i) {
- const short sx = (tiitg%NL1);
- const short sy = (tiitg/NL1)/8;
- const short lx = i;
- const short ly = (tiitg/NL1)%8;
- //const short lx = (tiitg/NL1)%8;
- //const short ly = i;
- *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
- }
- } else {
- const short sx = (tiitg%NL1);
- const short sy = (tiitg/NL1)/8;
- //const short lx = i;
- const short ly = (tiitg/NL1)%8;
- //const short lx = (tiitg/NL1)%8;
- //const short ly = i;
- *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
- }
- #endif
- il = (il + 2 < nl) ? il + 2 : il % 2;
- x = (il < 2) ? x + (2 + nl - 1)/nl : x;
- y += NK;
- threadgroup_barrier(mem_flags::mem_threadgroup);
- #ifndef GGML_METAL_HAS_TENSOR
- // load matrices from threadgroup memory and conduct outer products
- threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
- threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
- FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
- simdgroup_barrier(mem_flags::mem_none);
- FOR_UNROLL (short i = 0; i < 4; i++) {
- simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
- }
- simdgroup_barrier(mem_flags::mem_none);
- FOR_UNROLL (short i = 0; i < 2; i++) {
- simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
- }
- simdgroup_barrier(mem_flags::mem_none);
- FOR_UNROLL (short i = 0; i < 8; i++){
- simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
- }
- lsma += 8*64;
- lsmb += 4*64;
- }
- #else
- auto sA = tA.slice(0, 0);
- auto sB = tB.slice(0, 0);
- mm.run(sB, sA, cT);
- #endif
- }
- // block is smaller than 64x32, we should avoid writing data outside of the matrix
- threadgroup_barrier(mem_flags::mem_threadgroup);
- #ifdef GGML_METAL_HAS_TENSOR
- auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
- cT.store(tC);
- #else
- threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
- for (short i = 0; i < 8; i++) {
- simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
- }
- #endif
- threadgroup_barrier(mem_flags::mem_threadgroup);
- for (short j = sgitg; j < nr1; j += 4) {
- const int id = ids_i32[im*args.ne21 + r1 + j];
- const short ide = id % args.ne20;
- const short idt = id / args.ne20;
- device float * D = (device float *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
- device float4 * D4 = (device float4 *) D;
- threadgroup float * C = (threadgroup float *) shmem + j*NR0;
- threadgroup float4 * C4 = (threadgroup float4 *) C;
- int i = tiisg;
- for (; i < nr0/4; i += 32) {
- *(D4 + i) = *(C4 + i);
- }
- i = (4*(nr0/4)) + tiisg;
- for (; i < nr0; i += 32) {
- *(D + i) = *(C + i);
- }
- }
- }
- #define QK_NL 16
- //
- // get rows
- //
- typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;
- template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>;
- template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>;
- template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
- #endif
- typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
- template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
- template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
- template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
- template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
- template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
- template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
- template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
- template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
- template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
- template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
- template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
- template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
- template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
- template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
- template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
- template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
- template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
- template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
- template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
- template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
- //
- // set rows
- //
- typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;
- template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;
- template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
- template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;
- template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
- template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
- #endif
- typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;
- template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>;
- template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0, quantize_q8_0>;
- template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0, quantize_q4_0>;
- template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0, quantize_q4_0>;
- template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1, quantize_q4_1>;
- template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1, quantize_q4_1>;
- template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0, quantize_q5_0>;
- template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0, quantize_q5_0>;
- template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1, quantize_q5_1>;
- template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1, quantize_q5_1>;
- template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
- template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;
- //
- // matrix-matrix multiplication
- //
- typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t;
- template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
- #endif
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
- #endif
- template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
- //
- // indirect matrix-matrix multiplication
- //
- typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id;
- template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
- #endif
- template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
- template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
- #endif
- template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
- template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
- //
- // matrix-vector multiplication
- //
- typedef void (kernel_mul_mv_disp_t)(
- ggml_metal_kargs_mul_mv args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig,
- ushort tiisg);
- typedef void (kernel_mul_mv2_disp_t)(
- ggml_metal_kargs_mul_mv args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg);
- template<kernel_mul_mv_disp_t disp_fn>
- void mmv_fn(
- ggml_metal_kargs_mul_mv args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiitg,
- ushort tiisg,
- ushort sgitg) {
- disp_fn(args, src0, src1, dst, tgpig, tiisg);
- }
- template<kernel_mul_mv2_disp_t disp_fn>
- void mmv_fn(
- ggml_metal_kargs_mul_mv args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- threadgroup char * shmem,
- uint3 tgpig,
- ushort tiitg,
- ushort tiisg,
- ushort sgitg) {
- disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
- }
- typedef decltype(mmv_fn<kernel_mul_mv_t_t_disp<half, half, ggml_metal_kargs_mul_mv>>) mul_mv_disp_fn_t;
- template<mul_mv_disp_fn_t disp_fn>
- kernel void kernel_mul_mv_id(
- constant ggml_metal_kargs_mul_mv_id & args,
- device const char * src0s,
- device const char * src1,
- device char * dst,
- device const char * ids,
- threadgroup char * shmem [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiitg[[thread_index_in_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- const int iid1 = tgpig.z/args.nei0;
- const int idx = tgpig.z%args.nei0;
- tgpig.z = 0;
- const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx];
- const int64_t i11 = idx % args.ne11;
- const int64_t i12 = iid1;
- const int64_t i1 = idx;
- const int64_t i2 = i12;
- device const char * src0_cur = src0s + i02*args.nb02;
- device const char * src1_cur = src1 + i11*args.nb11 + i12*args.nb12;
- device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float);
- ggml_metal_kargs_mul_mv args0 = {
- /*.ne00 =*/ args.ne00,
- /*.ne01 =*/ args.ne01,
- /*.ne02 =*/ 1, // args.ne02,
- /*.nb00 =*/ args.nb00,
- /*.nb01 =*/ args.nb01,
- /*.nb02 =*/ args.nb02,
- /*.nb03 =*/ args.nb02, // args.ne02 == 1
- /*.ne10 =*/ args.ne10,
- /*.ne11 =*/ 1, // args.ne11,
- /*.ne12 =*/ 1, // args.ne12,
- /*.nb10 =*/ args.nb10,
- /*.nb11 =*/ args.nb11,
- /*.nb12 =*/ args.nb12,
- /*.nb13 =*/ args.nb12, // ne12 == 1
- /*.ne0 =*/ args.ne0,
- /*.ne1 =*/ 1, // args.ne1,
- /*.nr0 =*/ args.nr0,
- /*.r2 =*/ 1,
- /*.r3 =*/ 1,
- };
- disp_fn(
- args0,
- /* src0 */ src0_cur,
- /* src1 */ src1_cur,
- /* dst */ dst_cur,
- shmem,
- tgpig,
- tiitg,
- tiisg,
- sgitg);
- }
- typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>) kernel_mul_mv_id_t;
- typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>) kernel_mul_mv_id_4_t;
- template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>;
- template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<half, float>>>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<bfloat, float>>>;
- #endif
- template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>;
- template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<half, half4, float, float4>>>;
- #if defined(GGML_METAL_HAS_BF16)
- template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<bfloat, bfloat4, float, float4>>>;
- #endif
- template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
- template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
- template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
- template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
- template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1>>>;
- template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4>>>;
- template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K>>>;
- template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K>>>;
- template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K>>>;
- template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K>>>;
- template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K>>>;
- template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S>>>;
- template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M>>>;
- template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS>>>;
- template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS>>>;
- template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS>>>;
- template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S>>>;
- template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S>>>;
- template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL>>>;
- template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS>>>;
- kernel void kernel_pool_2d_max_f32(
- constant ggml_metal_kargs_pool_2d & args,
- device const float * src0,
- device float * dst,
- uint gid[[thread_position_in_grid]]) {
- if (gid >= args.np) {
- return;
- }
- const int idx = gid;
- const int I_HW = args.IH * args.IW;
- const int O_HW = args.OH * args.OW;
- const int nc = idx / O_HW;
- const int cur_oh = idx % O_HW / args.OW;
- const int cur_ow = idx % O_HW % args.OW;
- device const float * i_ptr = src0 + nc * I_HW;
- device float * o_ptr = dst + nc * O_HW;
- const int start_h = cur_oh * args.s1 - args.p1;
- const int bh = MAX(0, start_h);
- const int eh = MIN(args.IH, start_h + args.k1);
- const int start_w = cur_ow * args.s0 - args.p0;
- const int bw = MAX(0, start_w);
- const int ew = MIN(args.IW, start_w + args.k0);
- float res = -INFINITY;
- for (int i = bh; i < eh; i += 1) {
- for (int j = bw; j < ew; j += 1) {
- res = MAX(res, i_ptr[i * args.IW + j]);
- }
- }
- o_ptr[cur_oh * args.OW + cur_ow] = res;
- }
- kernel void kernel_pool_2d_avg_f32(
- constant ggml_metal_kargs_pool_2d & args,
- device const float * src0,
- device float * dst,
- uint gid[[thread_position_in_grid]]) {
- if (gid >= args.np) {
- return;
- }
- const int idx = gid;
- const int I_HW = args.IH * args.IW;
- const int O_HW = args.OH * args.OW;
- const int nc = idx / O_HW;
- const int cur_oh = idx % O_HW / args.OW;
- const int cur_ow = idx % O_HW % args.OW;
- device const float * i_ptr = src0 + nc * I_HW;
- device float * o_ptr = dst + nc * O_HW;
- const int start_h = cur_oh * args.s1 - args.p1;
- const int bh = MAX(0, start_h);
- const int eh = MIN(args.IH, start_h + args.k1);
- const int start_w = cur_ow * args.s0 - args.p0;
- const int bw = MAX(0, start_w);
- const int ew = MIN(args.IW, start_w + args.k0);
- // const float scale = 1. / ((eh - bh) * (ew - bw));
- const float scale = 1. / (args.k0 * args.k1);
- float res = 0;
- for (int i = bh; i < eh; i += 1) {
- for (int j = bw; j < ew; j += 1) {
- float cur = i_ptr[i * args.IW + j];
- res += cur * scale;
- }
- }
- o_ptr[cur_oh * args.OW + cur_ow] = res;
- }
- kernel void kernel_opt_step_adamw_f32(
- constant ggml_metal_kargs_opt_step_adamw & args,
- device float * x,
- device const float * g,
- device float * g_m,
- device float * g_v,
- device const float * pars,
- uint gid[[thread_position_in_grid]]) {
- if (gid >= args.np) {
- return;
- }
- const float alpha = pars[0];
- const float beta1 = pars[1];
- const float beta2 = pars[2];
- const float eps = pars[3];
- const float wd = pars[4];
- const float beta1h = pars[5];
- const float beta2h = pars[6];
- const float gi = g[gid];
- const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
- const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
- g_m[gid] = gmi;
- g_v[gid] = gvi;
- const float mh = gmi * beta1h;
- const float vh = sqrt(gvi * beta2h) + eps;
- x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
- }
- kernel void kernel_opt_step_sgd_f32(
- constant ggml_metal_kargs_opt_step_sgd & args,
- device float * x,
- device const float * g,
- device const float * pars,
- uint gid[[thread_position_in_grid]]) {
- if (gid >= args.np) {
- return;
- }
- x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
- }
|