ggml.c

#define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows
#define _USE_MATH_DEFINES // For M_PI on MSVC

#include "ggml-backend.h"
#include "ggml-impl.h"
#include "ggml-threading.h"
#include "ggml-cpu.h"
#include "ggml.h"

// FIXME: required here for quantization functions
#include "ggml-quants.h"

#ifdef GGML_USE_CPU_HBM
#include <hbwmalloc.h>
#endif

#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
#include <alloca.h>
#endif

#include <assert.h>
#include <errno.h>
#include <time.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <inttypes.h>
#include <stdio.h>
#include <float.h>
#include <limits.h>
#include <stdarg.h>
#include <signal.h>

#if defined(__gnu_linux__)
#include <syscall.h>
#endif

#if defined(__APPLE__)
#include <unistd.h>
#include <mach/mach.h>
#include <TargetConditionals.h>
#endif

#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#endif

#define UNUSED GGML_UNUSED

#if defined(_MSC_VER)
#define m512bh(p) p
#define m512i(p) p
#else
#define m512bh(p) (__m512bh)(p)
#define m512i(p) (__m512i)(p)
#endif

#if defined(__linux__) || \
    defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
    (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH)

#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/wait.h>

#if defined(__linux__)
#include <sys/prctl.h>
#endif

#if defined(__ANDROID__)
#include <unwind.h>
#include <dlfcn.h>
#include <stdio.h>

struct backtrace_state {
    void ** current;
    void ** end;
};

static _Unwind_Reason_Code unwind_callback(struct _Unwind_Context* context, void* arg) {
    struct backtrace_state * state = (struct backtrace_state *)arg;
    uintptr_t pc = _Unwind_GetIP(context);
    if (pc) {
        if (state->current == state->end) {
            return _URC_END_OF_STACK;
        } else {
            *state->current++ = (void*)pc;
        }
    }
    return _URC_NO_REASON;
}

static void ggml_print_backtrace_symbols(void) {
    const int max = 100;
    void* buffer[max];
    struct backtrace_state state = {buffer, buffer + max};
    _Unwind_Backtrace(unwind_callback, &state);
    int count = state.current - buffer;
    for (int idx = 0; idx < count; ++idx) {
        const void * addr = buffer[idx];
        const char * symbol = "";
        Dl_info info;
        if (dladdr(addr, &info) && info.dli_sname) {
            symbol = info.dli_sname;
        }
        fprintf(stderr, "%d: %p %s\n", idx, addr, symbol);
    }
}
#elif defined(__linux__) && defined(__GLIBC__)
#include <execinfo.h>
static void ggml_print_backtrace_symbols(void) {
    void * trace[100];
    int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0]));
    backtrace_symbols_fd(trace, nptrs, STDERR_FILENO);
}
#elif defined(__APPLE__)
#include <execinfo.h>
static void ggml_print_backtrace_symbols(void) {
    void * trace[100];
    int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0]));
    backtrace_symbols_fd(trace, nptrs, STDERR_FILENO);
}
#else
static void ggml_print_backtrace_symbols(void) {
    // platform not supported
}
#endif

void ggml_print_backtrace(void) {
    const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE");
    if (GGML_NO_BACKTRACE) {
        return;
    }
#if defined(__APPLE__)
    // On macOS, fork+debugger attachment is problematic due to:
    // 1. libdispatch "poisons" forked child processes
    // 2. lldb has issues attaching to parent from forked child
    // Use simple backtrace() instead to avoid Terminal.app crashes
    const char * GGML_BACKTRACE_LLDB = getenv("GGML_BACKTRACE_LLDB");
    if (!GGML_BACKTRACE_LLDB) {
        fprintf(stderr, "WARNING: Using native backtrace. Set GGML_BACKTRACE_LLDB for more info.\n");
        fprintf(stderr, "WARNING: GGML_BACKTRACE_LLDB may cause native MacOS Terminal.app to crash.\n");
        fprintf(stderr, "See: https://github.com/ggml-org/llama.cpp/pull/17869\n");
        ggml_print_backtrace_symbols();
        return;
    }
#endif
#if defined(__linux__)
    FILE * f = fopen("/proc/self/status", "r");
    size_t size = 0;
    char * line = NULL;
    ssize_t length = 0;
    while ((length = getline(&line, &size, f)) > 0) {
        if (!strncmp(line, "TracerPid:", sizeof("TracerPid:") - 1) &&
            (length != sizeof("TracerPid:\t0\n") - 1 || line[length - 2] != '0')) {
            // Already being debugged, and the breakpoint is the later abort()
            free(line);
            fclose(f);
            return;
        }
    }
    free(line);
    fclose(f);
    int lock[2] = { -1, -1 };
    (void) !pipe(lock); // Don't start gdb until after PR_SET_PTRACER
#endif
    const int parent_pid = getpid();
    const int child_pid = fork();
    if (child_pid < 0) { // error
#if defined(__linux__)
        close(lock[1]);
        close(lock[0]);
#endif
        return;
    } else if (child_pid == 0) { // child
        char attach[32];
        snprintf(attach, sizeof(attach), "attach %d", parent_pid);
#if defined(__linux__)
        close(lock[1]);
        (void) !read(lock[0], lock, 1);
        close(lock[0]);
#endif
        // try gdb
        execlp("gdb", "gdb", "--batch",
            "-ex", "set style enabled on",
            "-ex", attach,
            "-ex", "bt -frame-info source-and-location",
            "-ex", "detach",
            "-ex", "quit",
            (char *) NULL);
        // try lldb
        execlp("lldb", "lldb", "--batch",
            "-o", "bt",
            "-o", "quit",
            "-p", &attach[sizeof("attach ") - 1],
            (char *) NULL);
        // gdb failed, fallback to backtrace_symbols
        ggml_print_backtrace_symbols();
        _Exit(0);
    } else { // parent
#if defined(__linux__)
        prctl(PR_SET_PTRACER, child_pid);
        close(lock[1]);
        close(lock[0]);
#endif
        waitpid(child_pid, NULL, 0);
    }
}
#else
void ggml_print_backtrace(void) {
    // platform not supported
}
#endif
static ggml_abort_callback_t g_abort_callback = NULL;

// Set the abort callback (passing NULL restores the default behavior: printing the message and a backtrace to stderr)
GGML_API ggml_abort_callback_t ggml_set_abort_callback(ggml_abort_callback_t callback) {
    ggml_abort_callback_t ret_val = g_abort_callback;
    g_abort_callback = callback;
    return ret_val;
}

void ggml_abort(const char * file, int line, const char * fmt, ...) {
    fflush(stdout);

    char message[2048];
    int offset = snprintf(message, sizeof(message), "%s:%d: ", file, line);

    va_list args;
    va_start(args, fmt);
    vsnprintf(message + offset, sizeof(message) - offset, fmt, args);
    va_end(args);

    if (g_abort_callback) {
        g_abort_callback(message);
    } else {
        // default: print error and backtrace to stderr
        fprintf(stderr, "%s\n", message);
        ggml_print_backtrace();
    }

    abort();
}
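// Illustrative usage sketch (not part of the original file): installing a custom
// abort handler so fatal ggml errors are routed through the host application
// before the process aborts. This assumes the ggml_abort_callback_t typedef from
// ggml.h, i.e. a callback that receives the formatted message string.
#if 0
static void example_abort_handler(const char * message) {
    // forward to the application's own logging / crash reporting
    fprintf(stderr, "[app] fatal ggml error: %s\n", message);
}

static void example_install_abort_handler(void) {
    // the previous callback is returned so it can be restored later
    ggml_abort_callback_t prev = ggml_set_abort_callback(example_abort_handler);
    GGML_UNUSED(prev);
}
#endif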
// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp

//
// logging
//

struct ggml_logger_state {
    ggml_log_callback log_callback;
    void * log_callback_user_data;
};

static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};

static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
    if (format == NULL) {
        return;
    }
    va_list args_copy;
    va_copy(args_copy, args);
    char buffer[128];
    int len = vsnprintf(buffer, 128, format, args);
    if (len < 128) {
        g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
    } else {
        char * buffer2 = (char *) calloc(len + 1, sizeof(char));
        vsnprintf(buffer2, len + 1, format, args_copy);
        buffer2[len] = 0;
        g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);
        free(buffer2);
    }
    va_end(args_copy);
}

void ggml_log_internal(enum ggml_log_level level, const char * format, ...) {
    va_list args;
    va_start(args, format);
    ggml_log_internal_v(level, format, args);
    va_end(args);
}

void ggml_log_callback_default(enum ggml_log_level level, const char * text, void * user_data) {
    (void) level;
    (void) user_data;
    fputs(text, stderr);
    fflush(stderr);
}

//
// end of logging block
//
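// Illustrative usage sketch (not part of the original file): a log callback with
// the same signature as ggml_log_callback_default above. It is assumed here that
// it can be installed with ggml_log_set() from ggml.h, after which every message
// emitted through ggml_log_internal() reaches the host application instead of stderr.
#if 0
static void example_log_callback(enum ggml_log_level level, const char * text, void * user_data) {
    FILE * sink = (FILE *) user_data; // user_data is whatever the caller passed to ggml_log_set()
    fprintf(sink, "[ggml:%d] %s", (int) level, text);
}

static void example_install_log_callback(void) {
    ggml_log_set(example_log_callback, stderr);
}
#endif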
#ifdef GGML_USE_ACCELERATE
// uncomment to use vDSP for soft max computation
// note: not sure if it is actually faster
//#define GGML_SOFT_MAX_ACCELERATE
#endif

void * ggml_aligned_malloc(size_t size) {
#if defined(__s390x__)
    const int alignment = 256;
#else
    const int alignment = 64;
#endif

#if defined(_MSC_VER) || defined(__MINGW32__)
    return _aligned_malloc(size, alignment);
#else
    if (size == 0) {
        GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n");
        return NULL;
    }
    void * aligned_memory = NULL;
#ifdef GGML_USE_CPU_HBM
    int result = hbw_posix_memalign(&aligned_memory, alignment, size);
#elif TARGET_OS_OSX
    GGML_UNUSED(alignment);
    kern_return_t alloc_status = vm_allocate((vm_map_t) mach_task_self(), (vm_address_t *) &aligned_memory, size, VM_FLAGS_ANYWHERE);
    int result = EFAULT;
    switch (alloc_status) {
        case KERN_SUCCESS:
            result = 0;
            break;
        case KERN_INVALID_ADDRESS:
            result = EINVAL;
            break;
        case KERN_NO_SPACE:
            result = ENOMEM;
            break;
        default:
            result = EFAULT;
            break;
    }
#else
    int result = posix_memalign(&aligned_memory, alignment, size);
#endif
    if (result != 0) {
        // Handle allocation failure
        const char *error_desc = "unknown allocation error";
        switch (result) {
            case EINVAL:
                error_desc = "invalid alignment value";
                break;
            case ENOMEM:
                error_desc = "insufficient memory";
                break;
        }
        GGML_LOG_ERROR("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0));
        return NULL;
    }
    return aligned_memory;
#endif
}

void ggml_aligned_free(void * ptr, size_t size) {
    GGML_UNUSED(size);
#if defined(_MSC_VER) || defined(__MINGW32__)
    _aligned_free(ptr);
#elif GGML_USE_CPU_HBM
    if (ptr != NULL) {
        hbw_free(ptr);
    }
#elif TARGET_OS_OSX
    if (ptr != NULL) {
        vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ptr, size);
    }
#else
    free(ptr);
#endif
}
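// Illustrative usage sketch (not part of the original file): buffers obtained from
// ggml_aligned_malloc() must be released with ggml_aligned_free(), and on the
// TARGET_OS_OSX path above the original allocation size has to be passed back so
// vm_deallocate() can unmap the correct range.
#if 0
static void example_aligned_alloc(void) {
    const size_t size = 1024 * 1024;        // 1 MiB scratch buffer
    void * buf = ggml_aligned_malloc(size); // 64-byte aligned (256 on s390x)
    if (buf != NULL) {
        memset(buf, 0, size);
        ggml_aligned_free(buf, size);       // size is required on the macOS path
    }
}
#endif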
inline static void * ggml_malloc(size_t size) {
    if (size == 0) {
        GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_malloc!\n");
        return NULL;
    }
    void * result = malloc(size);
    if (result == NULL) {
        GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0));
        GGML_ABORT("fatal error");
    }
    return result;
}

// calloc
inline static void * ggml_calloc(size_t num, size_t size) {
    if (num == 0 || size == 0) {
        GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_calloc!\n");
        return NULL;
    }
    void * result = calloc(num, size);
    if (result == NULL) {
        GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, (num*size)/(1024.0*1024.0));
        GGML_ABORT("fatal error");
    }
    return result;
}

#define GGML_MALLOC(size) ggml_malloc(size)
#define GGML_CALLOC(num, size) ggml_calloc(num, size)
#define GGML_FREE(ptr) free(ptr)

const char * ggml_status_to_string(enum ggml_status status) {
    switch (status) {
        case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
        case GGML_STATUS_FAILED:       return "GGML status: error (operation failed)";
        case GGML_STATUS_SUCCESS:      return "GGML status: success";
        case GGML_STATUS_ABORTED:      return "GGML status: warning (operation aborted)";
    }

    return "GGML status: unknown";
}

float ggml_fp16_to_fp32(ggml_fp16_t x) {
#define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml
    return GGML_FP16_TO_FP32(x);
}

ggml_fp16_t ggml_fp32_to_fp16(float x) {
#define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml
    return GGML_FP32_TO_FP16(x);
}

float ggml_bf16_to_fp32(ggml_bf16_t x) {
#define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml
    return GGML_BF16_TO_FP32(x); // it just left shifts
}

ggml_bf16_t ggml_fp32_to_bf16(float x) {
#define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml
    return GGML_FP32_TO_BF16(x);
}

void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
    for (int64_t i = 0; i < n; i++) {
        y[i] = GGML_FP16_TO_FP32(x[i]);
    }
}

void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
    int i = 0;
    for (; i < n; ++i) {
        y[i] = GGML_FP32_TO_FP16(x[i]);
    }
}

void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
    int i = 0;
    for (; i < n; ++i) {
        y[i] = GGML_BF16_TO_FP32(x[i]);
    }
}

void ggml_fp32_to_bf16_row_ref(const float * x, ggml_bf16_t * y, int64_t n) {
    for (int i = 0; i < n; i++) {
        y[i] = ggml_compute_fp32_to_bf16(x[i]);
    }
}

void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
    int i = 0;
#if defined(__AVX512BF16__)
    // subnormals are flushed to zero on this platform
    for (; i + 32 <= n; i += 32) {
        _mm512_storeu_si512(
            (__m512i *)(y + i),
            m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
                                      _mm512_loadu_ps(x + i))));
    }
#endif
    for (; i < n; i++) {
        y[i] = GGML_FP32_TO_BF16(x[i]);
    }
}
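// Illustrative usage sketch (not part of the original file): round-tripping a small
// row of floats through the fp16 helpers defined above. Values that are exactly
// representable in half precision come back unchanged.
#if 0
static void example_fp16_round_trip(void) {
    const float src[4] = { 0.0f, 1.0f, -2.5f, 65504.0f }; // all exactly representable in fp16
    ggml_fp16_t packed[4];
    float dst[4];

    ggml_fp32_to_fp16_row(src, packed, 4);
    ggml_fp16_to_fp32_row(packed, dst, 4);
    // dst now equals src element-wise
}
#endif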
bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
    return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
}

const char * ggml_version(void) {
    return GGML_VERSION;
}

const char * ggml_commit(void) {
    return GGML_COMMIT;
}

//
// timing
//

#if defined(_MSC_VER) || defined(__MINGW32__)
static int64_t timer_freq, timer_start;

void ggml_time_init(void) {
    LARGE_INTEGER t;
    QueryPerformanceFrequency(&t);
    timer_freq = t.QuadPart;

    // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq
    // and the uptime are high enough.
    // We subtract the program start time to reduce the likelihood of that happening.
    QueryPerformanceCounter(&t);
    timer_start = t.QuadPart;
}

int64_t ggml_time_ms(void) {
    LARGE_INTEGER t;
    QueryPerformanceCounter(&t);
    return ((t.QuadPart-timer_start) * 1000) / timer_freq;
}

int64_t ggml_time_us(void) {
    LARGE_INTEGER t;
    QueryPerformanceCounter(&t);
    return ((t.QuadPart-timer_start) * 1000000) / timer_freq;
}
#else
void ggml_time_init(void) {}

int64_t ggml_time_ms(void) {
    struct timespec ts;
    clock_gettime(CLOCK_MONOTONIC, &ts);
    return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000;
}

int64_t ggml_time_us(void) {
    struct timespec ts;
    clock_gettime(CLOCK_MONOTONIC, &ts);
    return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000;
}
#endif

int64_t ggml_cycles(void) {
    return clock();
}

int64_t ggml_cycles_per_ms(void) {
    return CLOCKS_PER_SEC/1000;
}
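// Illustrative usage sketch (not part of the original file): timing a region of
// code with the helpers above. ggml_time_init() is called once to anchor the
// Windows performance counter; on POSIX platforms it is a no-op.
#if 0
static void example_measure_elapsed(void) {
    ggml_time_init();

    const int64_t t_start_us = ggml_time_us();
    // ... work to be measured ...
    const int64_t t_end_us = ggml_time_us();

    fprintf(stderr, "elapsed: %.3f ms\n", (double)(t_end_us - t_start_us) / 1000.0);
}
#endif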
//
// cross-platform UTF-8 file paths
//

#ifdef _WIN32
static wchar_t * ggml_mbstowcs(const char * mbs) {
    int wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, NULL, 0);
    if (!wlen) {
        errno = EINVAL;
        return NULL;
    }

    wchar_t * wbuf = GGML_MALLOC(wlen * sizeof(wchar_t));
    wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, wbuf, wlen);
    if (!wlen) {
        GGML_FREE(wbuf);
        errno = EINVAL;
        return NULL;
    }

    return wbuf;
}
#endif

FILE * ggml_fopen(const char * fname, const char * mode) {
#ifdef _WIN32
    FILE * file = NULL;

    // convert fname (UTF-8)
    wchar_t * wfname = ggml_mbstowcs(fname);
    if (wfname) {
        // convert mode (ANSI)
        wchar_t * wmode = GGML_MALLOC((strlen(mode) + 1) * sizeof(wchar_t));
        wchar_t * wmode_p = wmode;
        do {
            *wmode_p++ = (wchar_t)*mode;
        } while (*mode++);

        // open file
        file = _wfopen(wfname, wmode);

        GGML_FREE(wfname);
        GGML_FREE(wmode);
    }

    return file;
#else
    return fopen(fname, mode);
#endif
}
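// Illustrative usage sketch (not part of the original file): ggml_fopen() accepts a
// UTF-8 path on every platform. On Windows it widens the path and uses _wfopen();
// elsewhere it falls through to plain fopen(). The path below is hypothetical.
#if 0
static void example_open_utf8_path(void) {
    FILE * f = ggml_fopen("models/modèle-ünïcode.gguf", "rb"); // hypothetical path
    if (f != NULL) {
        // ... read the file ...
        fclose(f);
    }
}
#endif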
  531. static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
  532. [GGML_TYPE_I8] = {
  533. .type_name = "i8",
  534. .blck_size = 1,
  535. .type_size = sizeof(int8_t),
  536. .is_quantized = false,
  537. },
  538. [GGML_TYPE_I16] = {
  539. .type_name = "i16",
  540. .blck_size = 1,
  541. .type_size = sizeof(int16_t),
  542. .is_quantized = false,
  543. },
  544. [GGML_TYPE_I32] = {
  545. .type_name = "i32",
  546. .blck_size = 1,
  547. .type_size = sizeof(int32_t),
  548. .is_quantized = false,
  549. },
  550. [GGML_TYPE_I64] = {
  551. .type_name = "i64",
  552. .blck_size = 1,
  553. .type_size = sizeof(int64_t),
  554. .is_quantized = false,
  555. },
  556. [GGML_TYPE_F64] = {
  557. .type_name = "f64",
  558. .blck_size = 1,
  559. .type_size = sizeof(double),
  560. .is_quantized = false,
  561. },
  562. [GGML_TYPE_F32] = {
  563. .type_name = "f32",
  564. .blck_size = 1,
  565. .type_size = sizeof(float),
  566. .is_quantized = false,
  567. },
  568. [GGML_TYPE_F16] = {
  569. .type_name = "f16",
  570. .blck_size = 1,
  571. .type_size = sizeof(ggml_fp16_t),
  572. .is_quantized = false,
  573. .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row,
  574. .from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row,
  575. },
  576. [GGML_TYPE_Q4_0] = {
  577. .type_name = "q4_0",
  578. .blck_size = QK4_0,
  579. .type_size = sizeof(block_q4_0),
  580. .is_quantized = true,
  581. .to_float = (ggml_to_float_t) dequantize_row_q4_0,
  582. .from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref,
  583. },
  584. [GGML_TYPE_Q4_1] = {
  585. .type_name = "q4_1",
  586. .blck_size = QK4_1,
  587. .type_size = sizeof(block_q4_1),
  588. .is_quantized = true,
  589. .to_float = (ggml_to_float_t) dequantize_row_q4_1,
  590. .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref,
  591. },
  592. [4] = { // GGML_TYPE_Q4_2
  593. .type_name = "DEPRECATED",
  594. .blck_size = 0,
  595. .type_size = 0,
  596. .is_quantized = false,
  597. },
  598. [5] = { // GGML_TYPE_Q4_3
  599. .type_name = "DEPRECATED",
  600. .blck_size = 0,
  601. .type_size = 0,
  602. .is_quantized = false,
  603. },
  604. [GGML_TYPE_Q5_0] = {
  605. .type_name = "q5_0",
  606. .blck_size = QK5_0,
  607. .type_size = sizeof(block_q5_0),
  608. .is_quantized = true,
  609. .to_float = (ggml_to_float_t) dequantize_row_q5_0,
  610. .from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref,
  611. },
  612. [GGML_TYPE_Q5_1] = {
  613. .type_name = "q5_1",
  614. .blck_size = QK5_1,
  615. .type_size = sizeof(block_q5_1),
  616. .is_quantized = true,
  617. .to_float = (ggml_to_float_t) dequantize_row_q5_1,
  618. .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref,
  619. },
  620. [GGML_TYPE_Q8_0] = {
  621. .type_name = "q8_0",
  622. .blck_size = QK8_0,
  623. .type_size = sizeof(block_q8_0),
  624. .is_quantized = true,
  625. .to_float = (ggml_to_float_t) dequantize_row_q8_0,
  626. .from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref,
  627. },
  628. [GGML_TYPE_Q8_1] = {
  629. .type_name = "q8_1",
  630. .blck_size = QK8_1,
  631. .type_size = sizeof(block_q8_1),
  632. .is_quantized = true,
  633. .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref,
  634. },
  635. [GGML_TYPE_MXFP4] = {
  636. .type_name = "mxfp4",
  637. .blck_size = QK_MXFP4,
  638. .type_size = sizeof(block_mxfp4),
  639. .is_quantized = true,
  640. .to_float = (ggml_to_float_t) dequantize_row_mxfp4,
  641. .from_float_ref = (ggml_from_float_t)quantize_row_mxfp4_ref,
  642. },
  643. [GGML_TYPE_Q2_K] = {
  644. .type_name = "q2_K",
  645. .blck_size = QK_K,
  646. .type_size = sizeof(block_q2_K),
  647. .is_quantized = true,
  648. .to_float = (ggml_to_float_t) dequantize_row_q2_K,
  649. .from_float_ref = (ggml_from_float_t) quantize_row_q2_K_ref,
  650. },
  651. [GGML_TYPE_Q3_K] = {
  652. .type_name = "q3_K",
  653. .blck_size = QK_K,
  654. .type_size = sizeof(block_q3_K),
  655. .is_quantized = true,
  656. .to_float = (ggml_to_float_t) dequantize_row_q3_K,
  657. .from_float_ref = (ggml_from_float_t) quantize_row_q3_K_ref,
  658. },
  659. [GGML_TYPE_Q4_K] = {
  660. .type_name = "q4_K",
  661. .blck_size = QK_K,
  662. .type_size = sizeof(block_q4_K),
  663. .is_quantized = true,
  664. .to_float = (ggml_to_float_t) dequantize_row_q4_K,
  665. .from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref,
  666. },
  667. [GGML_TYPE_Q5_K] = {
  668. .type_name = "q5_K",
  669. .blck_size = QK_K,
  670. .type_size = sizeof(block_q5_K),
  671. .is_quantized = true,
  672. .to_float = (ggml_to_float_t) dequantize_row_q5_K,
  673. .from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref,
  674. },
  675. [GGML_TYPE_Q6_K] = {
  676. .type_name = "q6_K",
  677. .blck_size = QK_K,
  678. .type_size = sizeof(block_q6_K),
  679. .is_quantized = true,
  680. .to_float = (ggml_to_float_t) dequantize_row_q6_K,
  681. .from_float_ref = (ggml_from_float_t) quantize_row_q6_K_ref,
  682. },
  683. [GGML_TYPE_IQ2_XXS] = {
  684. .type_name = "iq2_xxs",
  685. .blck_size = QK_K,
  686. .type_size = sizeof(block_iq2_xxs),
  687. .is_quantized = true,
  688. .to_float = (ggml_to_float_t) dequantize_row_iq2_xxs,
  689. .from_float_ref = NULL,
  690. },
  691. [GGML_TYPE_IQ2_XS] = {
  692. .type_name = "iq2_xs",
  693. .blck_size = QK_K,
  694. .type_size = sizeof(block_iq2_xs),
  695. .is_quantized = true,
  696. .to_float = (ggml_to_float_t) dequantize_row_iq2_xs,
  697. .from_float_ref = NULL,
  698. },
  699. [GGML_TYPE_IQ3_XXS] = {
  700. .type_name = "iq3_xxs",
  701. .blck_size = QK_K,
  702. .type_size = sizeof(block_iq3_xxs),
  703. .is_quantized = true,
  704. .to_float = (ggml_to_float_t) dequantize_row_iq3_xxs,
  705. .from_float_ref = (ggml_from_float_t)quantize_row_iq3_xxs_ref,
  706. },
  707. [GGML_TYPE_IQ3_S] = {
  708. .type_name = "iq3_s",
  709. .blck_size = QK_K,
  710. .type_size = sizeof(block_iq3_s),
  711. .is_quantized = true,
  712. .to_float = (ggml_to_float_t) dequantize_row_iq3_s,
  713. .from_float_ref = (ggml_from_float_t)quantize_row_iq3_s_ref,
  714. },
  715. [GGML_TYPE_IQ2_S] = {
  716. .type_name = "iq2_s",
  717. .blck_size = QK_K,
  718. .type_size = sizeof(block_iq2_s),
  719. .is_quantized = true,
  720. .to_float = (ggml_to_float_t) dequantize_row_iq2_s,
  721. .from_float_ref = (ggml_from_float_t)quantize_row_iq2_s_ref,
  722. },
  723. [GGML_TYPE_IQ1_S] = {
  724. .type_name = "iq1_s",
  725. .blck_size = QK_K,
  726. .type_size = sizeof(block_iq1_s),
  727. .is_quantized = true,
  728. .to_float = (ggml_to_float_t) dequantize_row_iq1_s,
  729. .from_float_ref = NULL,
  730. },
  731. [GGML_TYPE_IQ1_M] = {
  732. .type_name = "iq1_m",
  733. .blck_size = QK_K,
  734. .type_size = sizeof(block_iq1_m),
  735. .is_quantized = true,
  736. .to_float = (ggml_to_float_t) dequantize_row_iq1_m,
  737. .from_float_ref = NULL,
  738. },
  739. [GGML_TYPE_IQ4_NL] = {
  740. .type_name = "iq4_nl",
  741. .blck_size = QK4_NL,
  742. .type_size = sizeof(block_iq4_nl),
  743. .is_quantized = true,
  744. .to_float = (ggml_to_float_t) dequantize_row_iq4_nl,
  745. .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref,
  746. },
  747. [GGML_TYPE_IQ4_XS] = {
  748. .type_name = "iq4_xs",
  749. .blck_size = QK_K,
  750. .type_size = sizeof(block_iq4_xs),
  751. .is_quantized = true,
  752. .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
  753. .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_ref,
  754. },
  755. [GGML_TYPE_Q8_K] = {
  756. .type_name = "q8_K",
  757. .blck_size = QK_K,
  758. .type_size = sizeof(block_q8_K),
  759. .is_quantized = true,
  760. },
  761. [GGML_TYPE_BF16] = {
  762. .type_name = "bf16",
  763. .blck_size = 1,
  764. .type_size = sizeof(ggml_bf16_t),
  765. .is_quantized = false,
  766. .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
  767. .from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref,
  768. },
  769. [31] = { // GGML_TYPE_Q4_0_4_4
  770. .type_name = "TYPE_Q4_0_4_4 REMOVED, use Q4_0 with runtime repacking",
  771. .blck_size = 0,
  772. .type_size = 0,
  773. .is_quantized = false,
  774. },
  775. [32] = { // GGML_TYPE_Q4_0_4_8
  776. .type_name = "TYPE_Q4_0_4_8 REMOVED, use Q4_0 with runtime repacking",
  777. .blck_size = 0,
  778. .type_size = 0,
  779. .is_quantized = false,
  780. },
  781. [33] = { // GGML_TYPE_Q4_0_8_8
  782. .type_name = "TYPE_Q4_0_8_8 REMOVED, use Q4_0 with runtime repacking",
  783. .blck_size = 0,
  784. .type_size = 0,
  785. .is_quantized = false,
  786. },
  787. [GGML_TYPE_TQ1_0] = {
  788. .type_name = "tq1_0",
  789. .blck_size = QK_K,
  790. .type_size = sizeof(block_tq1_0),
  791. .is_quantized = true,
  792. .to_float = (ggml_to_float_t) dequantize_row_tq1_0,
  793. .from_float_ref = (ggml_from_float_t) quantize_row_tq1_0_ref,
  794. },
  795. [GGML_TYPE_TQ2_0] = {
  796. .type_name = "tq2_0",
  797. .blck_size = QK_K,
  798. .type_size = sizeof(block_tq2_0),
  799. .is_quantized = true,
  800. .to_float = (ggml_to_float_t) dequantize_row_tq2_0,
  801. .from_float_ref = (ggml_from_float_t) quantize_row_tq2_0_ref,
  802. },
  803. [36] = { // GGML_TYPE_IQ4_NL_4_4
  804. .type_name = "TYPE_IQ4_NL_4_4 REMOVED, use IQ4_NL with runtime repacking",
  805. .blck_size = 0,
  806. .type_size = 0,
  807. .is_quantized = false,
  808. },
  809. [37] = { // GGML_TYPE_IQ4_NL_4_8
  810. .type_name = "TYPE_IQ4_NL_4_8 REMOVED, use IQ4_NL with runtime repacking",
  811. .blck_size = 0,
  812. .type_size = 0,
  813. .is_quantized = false,
  814. },
  815. [38] = { // GGML_TYPE_IQ4_NL_8_8
  816. .type_name = "TYPE_IQ4_NL_8_8 REMOVED, use IQ4_NL with runtime repacking",
  817. .blck_size = 0,
  818. .type_size = 0,
  819. .is_quantized = false,
  820. },
  821. };
  822. const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {
  823. GGML_ASSERT(type < GGML_TYPE_COUNT);
  824. return &type_traits[type];
  825. }
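// Illustrative use of the traits table (not part of ggml.c): querying how a
// type is laid out, using only the accessor defined above. The values shown
// match the GGML_TYPE_BF16 entry.
//
//   const struct ggml_type_traits * tt = ggml_get_type_traits(GGML_TYPE_BF16);
//   // tt->type_name    -> "bf16"
//   // tt->blck_size    -> 1
//   // tt->type_size    -> sizeof(ggml_bf16_t)
//   // tt->is_quantized -> false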
  826. //
  827. // ggml object
  828. //
  829. struct ggml_object {
  830. size_t offs;
  831. size_t size;
  832. struct ggml_object * next;
  833. enum ggml_object_type type;
  834. char padding[4];
  835. };
  836. static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
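// Every allocation carved out of a context's memory pool is described by one
// ggml_object header placed directly in front of its payload: `offs` is the
// payload offset within the pool, `size` is the (aligned) payload size, and
// `next` links the objects into a singly linked list. The trailing `padding`
// keeps sizeof(struct ggml_object) a multiple of GGML_MEM_ALIGN, which the
// static_assert further below verifies.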
  837. //
  838. // ggml context
  839. //
  840. struct ggml_context {
  841. size_t mem_size;
  842. void * mem_buffer;
  843. bool mem_buffer_owned;
  844. bool no_alloc;
  845. int n_objects;
  846. struct ggml_object * objects_begin;
  847. struct ggml_object * objects_end;
  848. };
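// A ggml_context is a simple bump allocator: `mem_buffer` is either supplied
// by the caller or allocated in ggml_init (tracked by `mem_buffer_owned`), and
// `objects_begin`/`objects_end` delimit the linked list of ggml_object headers
// carved from it. With `no_alloc` set, new tensors still get their metadata
// from the pool but their data pointer is left NULL so that storage can be
// provided externally.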
  849. //
  850. // data types
  851. //
  852. static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  853. "NONE",
  854. "DUP",
  855. "ADD",
  856. "ADD_ID",
  857. "ADD1",
  858. "ACC",
  859. "SUB",
  860. "MUL",
  861. "DIV",
  862. "SQR",
  863. "SQRT",
  864. "LOG",
  865. "SIN",
  866. "COS",
  867. "SUM",
  868. "SUM_ROWS",
  869. "CUMSUM",
  870. "MEAN",
  871. "ARGMAX",
  872. "COUNT_EQUAL",
  873. "REPEAT",
  874. "REPEAT_BACK",
  875. "CONCAT",
  876. "SILU_BACK",
  877. "NORM",
  878. "RMS_NORM",
  879. "RMS_NORM_BACK",
  880. "GROUP_NORM",
  881. "L2_NORM",
  882. "MUL_MAT",
  883. "MUL_MAT_ID",
  884. "OUT_PROD",
  885. "SCALE",
  886. "SET",
  887. "CPY",
  888. "CONT",
  889. "RESHAPE",
  890. "VIEW",
  891. "PERMUTE",
  892. "TRANSPOSE",
  893. "GET_ROWS",
  894. "GET_ROWS_BACK",
  895. "SET_ROWS",
  896. "DIAG",
  897. "DIAG_MASK_INF",
  898. "DIAG_MASK_ZERO",
  899. "SOFT_MAX",
  900. "SOFT_MAX_BACK",
  901. "ROPE",
  902. "ROPE_BACK",
  903. "CLAMP",
  904. "CONV_TRANSPOSE_1D",
  905. "IM2COL",
  906. "IM2COL_BACK",
  907. "IM2COL_3D",
  908. "CONV_2D",
  909. "CONV_3D",
  910. "CONV_2D_DW",
  911. "CONV_TRANSPOSE_2D",
  912. "POOL_1D",
  913. "POOL_2D",
  914. "POOL_2D_BACK",
  915. "UPSCALE",
  916. "PAD",
  917. "PAD_REFLECT_1D",
  918. "ROLL",
  919. "ARANGE",
  920. "TIMESTEP_EMBEDDING",
  921. "ARGSORT",
  922. "TOP_K",
  923. "LEAKY_RELU",
  924. "TRI",
  925. "FILL",
  926. "FLASH_ATTN_EXT",
  927. "FLASH_ATTN_BACK",
  928. "SSM_CONV",
  929. "SSM_SCAN",
  930. "WIN_PART",
  931. "WIN_UNPART",
  932. "GET_REL_POS",
  933. "ADD_REL_POS",
  934. "RWKV_WKV6",
  935. "GATED_LINEAR_ATTN",
  936. "RWKV_WKV7",
  937. "SOLVE_TRI",
  938. "UNARY",
  939. "MAP_CUSTOM1",
  940. "MAP_CUSTOM2",
  941. "MAP_CUSTOM3",
  942. "CUSTOM",
  943. "CROSS_ENTROPY_LOSS",
  944. "CROSS_ENTROPY_LOSS_BACK",
  945. "OPT_STEP_ADAMW",
  946. "OPT_STEP_SGD",
  947. "GLU",
  948. };
  949. static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
  950. static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  951. "none",
  952. "x",
  953. "x+y",
  954. "x[i]+y",
  955. "x+y",
  956. "view(x,nb,offset)+=y->x",
  957. "x-y",
  958. "x*y",
  959. "x/y",
  960. "x^2",
  961. "√x",
  962. "log(x)",
  963. "sin(x)",
  964. "cos(x)",
  965. "Σx",
  966. "Σx_k",
  967. "cumsum(x)",
  968. "Σx/n",
  969. "argmax(x)",
  970. "count_equal(x)",
  971. "repeat(x)",
  972. "repeat_back(x)",
  973. "concat(x, y)",
  974. "silu_back(x)",
  975. "norm(x)",
  976. "rms_norm(x)",
  977. "rms_norm_back(x)",
  978. "group_norm(x)",
  979. "l2_norm(x)",
  980. "X*Y",
  981. "X[i]*Y",
  982. "X*Y",
  983. "x*v",
  984. "y-\\>view(x)",
  985. "x-\\>y",
  986. "cont(x)",
  987. "reshape(x)",
  988. "view(x)",
  989. "permute(x)",
  990. "transpose(x)",
  991. "get_rows(x)",
  992. "get_rows_back(x)",
  993. "set_rows(x)",
  994. "diag(x)",
  995. "diag_mask_inf(x)",
  996. "diag_mask_zero(x)",
  997. "soft_max(x)",
  998. "soft_max_back(x)",
  999. "rope(x)",
  1000. "rope_back(x)",
  1001. "clamp(x)",
  1002. "conv_transpose_1d(x)",
  1003. "im2col(x)",
  1004. "im2col_back(x)",
  1005. "im2col_3d(x)",
  1006. "conv_2d(x)",
  1007. "conv_3d(x)",
  1008. "conv_2d_dw(x)",
  1009. "conv_transpose_2d(x)",
  1010. "pool_1d(x)",
  1011. "pool_2d(x)",
  1012. "pool_2d_back(x)",
  1013. "upscale(x)",
  1014. "pad(x)",
  1015. "pad_reflect_1d(x)",
  1016. "roll(x)",
  1017. "arange(start, stop, step)",
  1018. "timestep_embedding(timesteps, dim, max_period)",
  1019. "argsort(x)",
  1020. "top_k(x)",
  1021. "leaky_relu(x)",
  1022. "tri(x)",
  1023. "fill(x, c)",
  1024. "flash_attn_ext(x)",
  1025. "flash_attn_back(x)",
  1026. "ssm_conv(x)",
  1027. "ssm_scan(x)",
  1028. "win_part(x)",
  1029. "win_unpart(x)",
  1030. "get_rel_pos(x)",
  1031. "add_rel_pos(x)",
  1032. "rwkv_wkv6(k, v, r, tf, td, s)",
  1033. "gated_linear_attn(k, v, q, gate, s)",
  1034. "rwkv_wkv7(r, w, k, v, a, b, s)",
  1035. "A X = B, A triangular, solve X",
  1036. "unary(x)",
  1037. "map_custom(x)",
  1038. "map_custom(x,y)",
  1039. "map_custom(x,y,z)",
  1040. "custom(x)",
  1041. "cross_entropy_loss(x,y)",
  1042. "cross_entropy_loss_back(x,y)",
  1043. "adamw(x)",
  1044. "sgd(x)",
  1045. "glu(x)",
  1046. };
  1047. static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
  1048. static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
  1049. static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
  1050. "ABS",
  1051. "SGN",
  1052. "NEG",
  1053. "STEP",
  1054. "TANH",
  1055. "ELU",
  1056. "RELU",
  1057. "SIGMOID",
  1058. "GELU",
  1059. "GELU_QUICK",
  1060. "SILU",
  1061. "HARDSWISH",
  1062. "HARDSIGMOID",
  1063. "EXP",
  1064. "EXPM1",
  1065. "SOFTPLUS",
  1066. "GELU_ERF",
  1067. "XIELU",
  1068. "FLOOR",
  1069. "CEIL",
  1070. "ROUND",
  1071. "TRUNC",
  1072. };
  1073. static_assert(GGML_UNARY_OP_COUNT == 22, "GGML_UNARY_OP_COUNT != 22");
  1074. static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
  1075. "REGLU",
  1076. "GEGLU",
  1077. "SWIGLU",
  1078. "SWIGLU_OAI",
  1079. "GEGLU_ERF",
  1080. "GEGLU_QUICK",
  1081. };
  1082. static_assert(GGML_GLU_OP_COUNT == 6, "GGML_GLU_OP_COUNT != 6");
  1083. static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
  1084. static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
  1085. ////////////////////////////////////////////////////////////////////////////////
  1086. void ggml_print_object(const struct ggml_object * obj) {
  1087. GGML_LOG_INFO(" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n",
  1088. obj->type, obj->offs, obj->size, (const void *) obj->next);
  1089. }
  1090. void ggml_print_objects(const struct ggml_context * ctx) {
  1091. struct ggml_object * obj = ctx->objects_begin;
  1092. GGML_LOG_INFO("%s: objects in context %p:\n", __func__, (const void *) ctx);
  1093. while (obj != NULL) {
  1094. ggml_print_object(obj);
  1095. obj = obj->next;
  1096. }
  1097. GGML_LOG_INFO("%s: --- end ---\n", __func__);
  1098. }
  1099. int64_t ggml_nelements(const struct ggml_tensor * tensor) {
  1100. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1101. return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
  1102. }
  1103. int64_t ggml_nrows(const struct ggml_tensor * tensor) {
  1104. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1105. return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
  1106. }
  1107. size_t ggml_nbytes(const struct ggml_tensor * tensor) {
  1108. for (int i = 0; i < GGML_MAX_DIMS; ++i) {
  1109. if (tensor->ne[i] <= 0) {
  1110. return 0;
  1111. }
  1112. }
  1113. size_t nbytes;
  1114. const size_t blck_size = ggml_blck_size(tensor->type);
  1115. if (blck_size == 1) {
  1116. nbytes = ggml_type_size(tensor->type);
  1117. for (int i = 0; i < GGML_MAX_DIMS; ++i) {
  1118. nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
  1119. }
  1120. }
  1121. else {
  1122. nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
  1123. for (int i = 1; i < GGML_MAX_DIMS; ++i) {
  1124. nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
  1125. }
  1126. }
  1127. return nbytes;
  1128. }
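// Worked example (comment only): for a freshly created, contiguous F32 tensor
// with ne = {4, 3, 1, 1} the strides are nb = {4, 16, 48, 48}, so the
// blck_size == 1 branch gives
//   nbytes = 4 + (4-1)*4 + (3-1)*16 + 0 + 0 = 48 bytes (12 floats),
// while for block-quantized types the first dimension is counted in whole
// blocks via ne[0]*nb[0]/blck_size.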
  1129. size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
  1130. return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN);
  1131. }
  1132. int64_t ggml_blck_size(enum ggml_type type) {
  1133. return type_traits[type].blck_size;
  1134. }
  1135. size_t ggml_type_size(enum ggml_type type) {
  1136. return type_traits[type].type_size;
  1137. }
  1138. size_t ggml_row_size(enum ggml_type type, int64_t ne) {
  1139. assert(ne % ggml_blck_size(type) == 0);
  1140. return ggml_type_size(type)*ne/ggml_blck_size(type);
  1141. }
  1142. double ggml_type_sizef(enum ggml_type type) {
  1143. return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;
  1144. }
  1145. const char * ggml_type_name(enum ggml_type type) {
  1146. return type < GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE";
  1147. }
  1148. bool ggml_is_quantized(enum ggml_type type) {
  1149. return type_traits[type].is_quantized;
  1150. }
  1151. const char * ggml_op_name(enum ggml_op op) {
  1152. return GGML_OP_NAME[op];
  1153. }
  1154. const char * ggml_op_symbol(enum ggml_op op) {
  1155. return GGML_OP_SYMBOL[op];
  1156. }
  1157. const char * ggml_unary_op_name(enum ggml_unary_op op) {
  1158. return GGML_UNARY_OP_NAME[op];
  1159. }
  1160. const char * ggml_glu_op_name(enum ggml_glu_op op) {
  1161. return GGML_GLU_OP_NAME[op];
  1162. }
  1163. const char * ggml_op_desc(const struct ggml_tensor * t) {
  1164. if (t->op == GGML_OP_UNARY) {
  1165. enum ggml_unary_op uop = ggml_get_unary_op(t);
  1166. return ggml_unary_op_name(uop);
  1167. }
  1168. if (t->op == GGML_OP_GLU) {
  1169. enum ggml_glu_op gop = ggml_get_glu_op(t);
  1170. return ggml_glu_op_name(gop);
  1171. }
  1172. return ggml_op_name(t->op);
  1173. }
  1174. size_t ggml_element_size(const struct ggml_tensor * tensor) {
  1175. return ggml_type_size(tensor->type);
  1176. }
  1177. bool ggml_is_scalar(const struct ggml_tensor * tensor) {
  1178. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1179. return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
  1180. }
  1181. bool ggml_is_vector(const struct ggml_tensor * tensor) {
  1182. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1183. return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
  1184. }
  1185. bool ggml_is_matrix(const struct ggml_tensor * tensor) {
  1186. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1187. return tensor->ne[2] == 1 && tensor->ne[3] == 1;
  1188. }
  1189. bool ggml_is_3d(const struct ggml_tensor * tensor) {
  1190. return tensor->ne[3] == 1;
  1191. }
  1192. int ggml_n_dims(const struct ggml_tensor * tensor) {
  1193. for (int i = GGML_MAX_DIMS - 1; i >= 1; --i) {
  1194. if (tensor->ne[i] > 1) {
  1195. return i + 1;
  1196. }
  1197. }
  1198. return 1;
  1199. }
  1200. enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
  1201. enum ggml_type wtype = GGML_TYPE_COUNT;
  1202. switch (ftype) {
  1203. case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
  1204. case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
  1205. case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break;
  1206. case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
  1207. case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
  1208. case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
  1209. case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
  1210. case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
  1211. case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break;
  1212. case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
  1213. case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
  1214. case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
  1215. case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
  1216. case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
  1217. case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
  1218. case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
  1219. case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
  1220. case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break;
  1221. case GGML_FTYPE_MOSTLY_IQ1_M: wtype = GGML_TYPE_IQ1_M; break;
  1222. case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
  1223. case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
  1224. case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
  1225. case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break;
  1226. case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
  1227. case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
  1228. }
  1229. GGML_ASSERT(wtype != GGML_TYPE_COUNT);
  1230. return wtype;
  1231. }
  1232. size_t ggml_tensor_overhead(void) {
  1233. return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE;
  1234. }
  1235. bool ggml_is_transposed(const struct ggml_tensor * tensor) {
  1236. return tensor->nb[0] > tensor->nb[1];
  1237. }
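// ggml_is_contiguous_n: true if the tensor is contiguous except that the
// strides of dimensions 1..n are allowed to include padding. Dimension 0 must
// always be tightly packed (or span exactly one block), and every dimension
// above n must follow directly on top of whatever stride the lower dimensions
// produce; n = 0 therefore means fully contiguous.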
  1238. static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) {
  1239. size_t next_nb = ggml_type_size(tensor->type);
  1240. if (tensor->ne[0] != ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) {
  1241. return false;
  1242. }
  1243. next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type);
  1244. for (int i = 1; i < GGML_MAX_DIMS; i++) {
  1245. if (tensor->ne[i] != 1) {
  1246. if (i > n) {
  1247. if (tensor->nb[i] != next_nb) {
  1248. return false;
  1249. }
  1250. next_nb *= tensor->ne[i];
  1251. } else {
  1252. // this dimension does not need to be contiguous
  1253. next_nb = tensor->ne[i]*tensor->nb[i];
  1254. }
  1255. }
  1256. }
  1257. return true;
  1258. }
  1259. bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
  1260. return ggml_is_contiguous_0(tensor);
  1261. }
  1262. bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
  1263. return ggml_is_contiguous_n(tensor, 0);
  1264. }
  1265. bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
  1266. return ggml_is_contiguous_n(tensor, 1);
  1267. }
  1268. bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
  1269. return ggml_is_contiguous_n(tensor, 2);
  1270. }
  1271. bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) {
  1272. return ggml_nbytes(tensor) == ggml_nelements(tensor) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
  1273. }
  1274. bool ggml_is_permuted(const struct ggml_tensor * tensor) {
  1275. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1276. return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
  1277. }
  1278. bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
  1279. return
  1280. tensor->nb[0] > tensor->nb[2] &&
  1281. tensor->nb[1] > tensor->nb[0] &&
  1282. tensor->nb[2] == ggml_type_size(tensor->type);
  1283. }
  1284. bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {
  1285. return
  1286. tensor->ne[0] == ggml_blck_size(tensor->type) ||
  1287. tensor->nb[0] == ggml_type_size(tensor->type);
  1288. }
  1289. static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
  1290. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1291. return
  1292. tensor->nb[0] == ggml_type_size(tensor->type) &&
  1293. tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
  1294. tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
  1295. }
  1296. bool ggml_is_empty(const struct ggml_tensor * tensor) {
  1297. for (int i = 0; i < GGML_MAX_DIMS; ++i) {
  1298. if (tensor->ne[i] == 0) {
  1299. // empty if any dimension has no elements
  1300. return true;
  1301. }
  1302. }
  1303. return false;
  1304. }
  1305. bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  1306. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1307. return
  1308. (t0->ne[0] == t1->ne[0]) &&
  1309. (t0->ne[1] == t1->ne[1]) &&
  1310. (t0->ne[2] == t1->ne[2]) &&
  1311. (t0->ne[3] == t1->ne[3]);
  1312. }
  1313. bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  1314. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1315. return
  1316. (t0->nb[0] == t1->nb[0]) &&
  1317. (t0->nb[1] == t1->nb[1]) &&
  1318. (t0->nb[2] == t1->nb[2]) &&
  1319. (t0->nb[3] == t1->nb[3]);
  1320. }
  1321. // check if t1 can be represented as a repetition of t0
  1322. bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  1323. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1324. return ggml_is_empty(t0) ? ggml_is_empty(t1) :
  1325. (t1->ne[0]%t0->ne[0] == 0) &&
  1326. (t1->ne[1]%t0->ne[1] == 0) &&
  1327. (t1->ne[2]%t0->ne[2] == 0) &&
  1328. (t1->ne[3]%t0->ne[3] == 0);
  1329. }
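// Example (comment only): a bias with ne = {4096, 1, 1, 1} can be repeated
// into an activation with ne = {4096, 32, 1, 1} because every t1 dimension is
// a multiple of the corresponding t0 dimension; {3, 1, 1, 1} cannot be
// repeated into {4, 4, 1, 1} since 4 % 3 != 0.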
  1330. static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  1331. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  1332. return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1);
  1333. }
  1334. // assert that pointer is aligned to GGML_MEM_ALIGN
  1335. #define GGML_ASSERT_ALIGNED(ptr) \
  1336. GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
  1337. ////////////////////////////////////////////////////////////////////////////////
  1338. struct ggml_context * ggml_init(struct ggml_init_params params) {
  1339. static bool is_first_call = true;
  1340. ggml_critical_section_start();
  1341. if (is_first_call) {
  1342. // initialize time system (required on Windows)
  1343. ggml_time_init();
  1344. is_first_call = false;
  1345. }
  1346. ggml_critical_section_end();
  1347. struct ggml_context * ctx = GGML_MALLOC(sizeof(struct ggml_context));
1348. // allow calling ggml_init with mem_size == 0
  1349. if (params.mem_size == 0) {
  1350. params.mem_size = GGML_MEM_ALIGN;
  1351. }
  1352. const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN);
  1353. *ctx = (struct ggml_context) {
  1354. /*.mem_size =*/ mem_size,
  1355. /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : ggml_aligned_malloc(mem_size),
  1356. /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
  1357. /*.no_alloc =*/ params.no_alloc,
  1358. /*.n_objects =*/ 0,
  1359. /*.objects_begin =*/ NULL,
  1360. /*.objects_end =*/ NULL,
  1361. };
  1362. GGML_ASSERT(ctx->mem_buffer != NULL);
  1363. GGML_ASSERT_ALIGNED(ctx->mem_buffer);
  1364. GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
  1365. return ctx;
  1366. }
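// Illustrative lifecycle (not part of ggml.c; sizes are arbitrary): a context
// that owns its own pool and is torn down with ggml_free. Assumes only
// struct ggml_init_params as declared in ggml.h.
//
//   struct ggml_init_params params = {
//       /*.mem_size   =*/ 16*1024*1024,  // bytes reserved for the pool
//       /*.mem_buffer =*/ NULL,          // let ggml allocate -> mem_buffer_owned
//       /*.no_alloc   =*/ false,         // tensor data lives in the pool
//   };
//   struct ggml_context * ctx = ggml_init(params);
//   // ... create tensors / build graphs ...
//   ggml_free(ctx);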
  1367. void ggml_reset(struct ggml_context * ctx) {
  1368. if (ctx == NULL) {
  1369. return;
  1370. }
  1371. ctx->n_objects = 0;
  1372. ctx->objects_begin = NULL;
  1373. ctx->objects_end = NULL;
  1374. }
  1375. void ggml_free(struct ggml_context * ctx) {
  1376. if (ctx == NULL) {
  1377. return;
  1378. }
  1379. if (ctx->mem_buffer_owned) {
  1380. ggml_aligned_free(ctx->mem_buffer, ctx->mem_size);
  1381. }
  1382. GGML_FREE(ctx);
  1383. }
  1384. size_t ggml_used_mem(const struct ggml_context * ctx) {
  1385. return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size;
  1386. }
  1387. bool ggml_get_no_alloc(struct ggml_context * ctx) {
  1388. return ctx->no_alloc;
  1389. }
  1390. void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) {
  1391. ctx->no_alloc = no_alloc;
  1392. }
  1393. void * ggml_get_mem_buffer(const struct ggml_context * ctx) {
  1394. return ctx->mem_buffer;
  1395. }
  1396. size_t ggml_get_mem_size(const struct ggml_context * ctx) {
  1397. return ctx->mem_size;
  1398. }
  1399. size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
  1400. size_t max_size = 0;
  1401. for (struct ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor != NULL; tensor = ggml_get_next_tensor(ctx, tensor)) {
  1402. size_t bytes = ggml_nbytes(tensor);
  1403. max_size = MAX(max_size, bytes);
  1404. }
  1405. return max_size;
  1406. }
  1407. ////////////////////////////////////////////////////////////////////////////////
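// Objects are appended back to back at the end of the pool: a ggml_object
// header is written at the current end, immediately followed by the payload,
// whose size is rounded up to GGML_MEM_ALIGN. obj->offs is the offset of the
// payload (not of the header) inside ctx->mem_buffer, which is why
// ggml_used_mem above reports usage as objects_end->offs + objects_end->size.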
  1408. static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) {
  1409. // always insert objects at the end of the context's memory pool
  1410. struct ggml_object * obj_cur = ctx->objects_end;
  1411. const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs;
  1412. const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
  1413. const size_t cur_end = cur_offs + cur_size;
  1414. // align to GGML_MEM_ALIGN
  1415. size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN);
  1416. char * const mem_buffer = ctx->mem_buffer;
  1417. struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
  1418. if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
  1419. GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
  1420. __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
  1421. #ifndef NDEBUG
  1422. GGML_ABORT("not enough space in the context's memory pool");
  1423. #endif
  1424. return NULL;
  1425. }
  1426. *obj_new = (struct ggml_object) {
  1427. .offs = cur_end + GGML_OBJECT_SIZE,
  1428. .size = size_needed,
  1429. .next = NULL,
  1430. .type = type,
  1431. };
  1432. GGML_ASSERT_ALIGNED(mem_buffer + obj_new->offs);
  1433. if (obj_cur != NULL) {
  1434. obj_cur->next = obj_new;
  1435. } else {
  1436. // this is the first object in this context
  1437. ctx->objects_begin = obj_new;
  1438. }
  1439. ctx->objects_end = obj_new;
  1440. //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);
  1441. return obj_new;
  1442. }
  1443. static struct ggml_tensor * ggml_new_tensor_impl(
  1444. struct ggml_context * ctx,
  1445. enum ggml_type type,
  1446. int n_dims,
  1447. const int64_t * ne,
  1448. struct ggml_tensor * view_src,
  1449. size_t view_offs) {
  1450. GGML_ASSERT(type >= 0 && type < GGML_TYPE_COUNT);
  1451. GGML_ASSERT(n_dims >= 1 && n_dims <= GGML_MAX_DIMS);
  1452. // find the base tensor and absolute offset
  1453. if (view_src != NULL && view_src->view_src != NULL) {
  1454. view_offs += view_src->view_offs;
  1455. view_src = view_src->view_src;
  1456. }
  1457. size_t data_size = ggml_row_size(type, ne[0]);
  1458. for (int i = 1; i < n_dims; i++) {
  1459. data_size *= ne[i];
  1460. }
  1461. GGML_ASSERT(view_src == NULL || data_size == 0 || data_size + view_offs <= ggml_nbytes(view_src));
  1462. void * data = view_src != NULL ? view_src->data : NULL;
  1463. if (data != NULL) {
  1464. data = (char *) data + view_offs;
  1465. }
  1466. size_t obj_alloc_size = 0;
  1467. if (view_src == NULL && !ctx->no_alloc) {
  1468. // allocate tensor data in the context's memory pool
  1469. obj_alloc_size = data_size;
  1470. }
  1471. struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size);
  1472. GGML_ASSERT(obj_new);
  1473. struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
  1474. *result = (struct ggml_tensor) {
  1475. /*.type =*/ type,
  1476. /*.buffer =*/ NULL,
  1477. /*.ne =*/ { 1, 1, 1, 1 },
  1478. /*.nb =*/ { 0, 0, 0, 0 },
  1479. /*.op =*/ GGML_OP_NONE,
  1480. /*.op_params =*/ { 0 },
  1481. /*.flags =*/ 0,
  1482. /*.src =*/ { NULL },
  1483. /*.view_src =*/ view_src,
  1484. /*.view_offs =*/ view_offs,
  1485. /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
  1486. /*.name =*/ { 0 },
  1487. /*.extra =*/ NULL,
  1488. /*.padding =*/ { 0 },
  1489. };
  1490. // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
  1491. //GGML_ASSERT_ALIGNED(result->data);
  1492. for (int i = 0; i < n_dims; i++) {
  1493. result->ne[i] = ne[i];
  1494. }
  1495. result->nb[0] = ggml_type_size(type);
  1496. result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type));
  1497. for (int i = 2; i < GGML_MAX_DIMS; i++) {
  1498. result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
  1499. }
  1500. ctx->n_objects++;
  1501. return result;
  1502. }
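// Comment only: the strides set above are byte strides. nb[0] is the element
// (or block) size and nb[1] is the row size in bytes; note the division by the
// block size, so e.g. an F32 tensor with ne = {8, 4, 2, 1} gets
// nb = {4, 32, 128, 256}, while a block-quantized row is measured in whole
// blocks. Higher dimensions simply multiply up: nb[i] = nb[i-1]*ne[i-1].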
  1503. struct ggml_tensor * ggml_new_tensor(
  1504. struct ggml_context * ctx,
  1505. enum ggml_type type,
  1506. int n_dims,
  1507. const int64_t * ne) {
  1508. return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL, 0);
  1509. }
  1510. struct ggml_tensor * ggml_new_tensor_1d(
  1511. struct ggml_context * ctx,
  1512. enum ggml_type type,
  1513. int64_t ne0) {
  1514. return ggml_new_tensor(ctx, type, 1, &ne0);
  1515. }
  1516. struct ggml_tensor * ggml_new_tensor_2d(
  1517. struct ggml_context * ctx,
  1518. enum ggml_type type,
  1519. int64_t ne0,
  1520. int64_t ne1) {
  1521. const int64_t ne[2] = { ne0, ne1 };
  1522. return ggml_new_tensor(ctx, type, 2, ne);
  1523. }
  1524. struct ggml_tensor * ggml_new_tensor_3d(
  1525. struct ggml_context * ctx,
  1526. enum ggml_type type,
  1527. int64_t ne0,
  1528. int64_t ne1,
  1529. int64_t ne2) {
  1530. const int64_t ne[3] = { ne0, ne1, ne2 };
  1531. return ggml_new_tensor(ctx, type, 3, ne);
  1532. }
  1533. struct ggml_tensor * ggml_new_tensor_4d(
  1534. struct ggml_context * ctx,
  1535. enum ggml_type type,
  1536. int64_t ne0,
  1537. int64_t ne1,
  1538. int64_t ne2,
  1539. int64_t ne3) {
  1540. const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
  1541. return ggml_new_tensor(ctx, type, 4, ne);
  1542. }
  1543. void * ggml_new_buffer(struct ggml_context * ctx, size_t nbytes) {
  1544. struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, nbytes);
  1545. return (uint8_t *)ctx->mem_buffer + obj->offs;
  1546. }
  1547. struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) {
  1548. return ggml_new_tensor(ctx, src->type, GGML_MAX_DIMS, src->ne);
  1549. }
  1550. void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
  1551. const int64_t ne2 = tensor->ne[2];
  1552. const int64_t ne1 = tensor->ne[1];
  1553. const int64_t ne0 = tensor->ne[0];
  1554. const int64_t i3_ = (i/(ne2*ne1*ne0));
  1555. const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);
  1556. const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;
  1557. const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);
  1558. if (i0) {
  1559. * i0 = i0_;
  1560. }
  1561. if (i1) {
  1562. * i1 = i1_;
  1563. }
  1564. if (i2) {
  1565. * i2 = i2_;
  1566. }
  1567. if (i3) {
  1568. * i3 = i3_;
  1569. }
  1570. }
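// Worked example (comment only): for ne = {4, 3, 2, 1} and flat index i = 10,
// i3 = 10/24 = 0, i2 = 10/12 = 0, i1 = 10/4 = 2, i0 = 10 - 2*4 = 2,
// i.e. the element at (i0, i1, i2, i3) = (2, 2, 0, 0).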
  1571. void * ggml_get_data(const struct ggml_tensor * tensor) {
  1572. return tensor->data;
  1573. }
  1574. float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
  1575. assert(tensor->type == GGML_TYPE_F32);
  1576. return (float *)(tensor->data);
  1577. }
  1578. enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
  1579. GGML_ASSERT(tensor->op == GGML_OP_UNARY);
  1580. return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
  1581. }
  1582. enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {
  1583. GGML_ASSERT(tensor->op == GGML_OP_GLU);
  1584. return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
  1585. }
  1586. const char * ggml_get_name(const struct ggml_tensor * tensor) {
  1587. return tensor->name;
  1588. }
  1589. struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) {
  1590. size_t i;
  1591. for (i = 0; i < sizeof(tensor->name) - 1 && name[i] != '\0'; i++) {
  1592. tensor->name[i] = name[i];
  1593. }
  1594. tensor->name[i] = '\0';
  1595. return tensor;
  1596. }
  1597. struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...) {
  1598. va_list args;
  1599. va_start(args, fmt);
  1600. vsnprintf(tensor->name, sizeof(tensor->name), fmt, args);
  1601. va_end(args);
  1602. return tensor;
  1603. }
  1604. struct ggml_tensor * ggml_view_tensor(
  1605. struct ggml_context * ctx,
  1606. struct ggml_tensor * src) {
  1607. struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, GGML_MAX_DIMS, src->ne, src, 0);
  1608. ggml_format_name(result, "%s (view)", src->name);
  1609. for (int i = 0; i < GGML_MAX_DIMS; i++) {
  1610. result->nb[i] = src->nb[i];
  1611. }
  1612. return result;
  1613. }
  1614. struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx) {
  1615. struct ggml_object * obj = ctx->objects_begin;
  1616. char * const mem_buffer = ctx->mem_buffer;
  1617. while (obj != NULL) {
  1618. if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
  1619. return (struct ggml_tensor *)(mem_buffer + obj->offs);
  1620. }
  1621. obj = obj->next;
  1622. }
  1623. return NULL;
  1624. }
  1625. struct ggml_tensor * ggml_get_next_tensor(const struct ggml_context * ctx, struct ggml_tensor * tensor) {
  1626. struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE);
  1627. obj = obj->next;
  1628. char * const mem_buffer = ctx->mem_buffer;
  1629. while (obj != NULL) {
  1630. if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
  1631. return (struct ggml_tensor *)(mem_buffer + obj->offs);
  1632. }
  1633. obj = obj->next;
  1634. }
  1635. return NULL;
  1636. }
  1637. struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) {
  1638. struct ggml_object * obj = ctx->objects_begin;
  1639. char * const mem_buffer = ctx->mem_buffer;
  1640. while (obj != NULL) {
  1641. if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
  1642. struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs);
  1643. if (strcmp(cur->name, name) == 0) {
  1644. return cur;
  1645. }
  1646. }
  1647. obj = obj->next;
  1648. }
  1649. return NULL;
  1650. }
  1651. ////////////////////////////////////////////////////////////////////////////////
  1652. // ggml_dup
  1653. static struct ggml_tensor * ggml_dup_impl(
  1654. struct ggml_context * ctx,
  1655. struct ggml_tensor * a,
  1656. bool inplace) {
  1657. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1658. result->op = GGML_OP_DUP;
  1659. result->src[0] = a;
  1660. return result;
  1661. }
  1662. struct ggml_tensor * ggml_dup(
  1663. struct ggml_context * ctx,
  1664. struct ggml_tensor * a) {
  1665. return ggml_dup_impl(ctx, a, false);
  1666. }
  1667. struct ggml_tensor * ggml_dup_inplace(
  1668. struct ggml_context * ctx,
  1669. struct ggml_tensor * a) {
  1670. return ggml_dup_impl(ctx, a, true);
  1671. }
  1672. // ggml_add
  1673. static struct ggml_tensor * ggml_add_impl(
  1674. struct ggml_context * ctx,
  1675. struct ggml_tensor * a,
  1676. struct ggml_tensor * b,
  1677. bool inplace) {
  1678. GGML_ASSERT(ggml_can_repeat(b, a));
  1679. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1680. result->op = GGML_OP_ADD;
  1681. result->src[0] = a;
  1682. result->src[1] = b;
  1683. return result;
  1684. }
  1685. struct ggml_tensor * ggml_add(
  1686. struct ggml_context * ctx,
  1687. struct ggml_tensor * a,
  1688. struct ggml_tensor * b) {
  1689. return ggml_add_impl(ctx, a, b, false);
  1690. }
  1691. struct ggml_tensor * ggml_add_inplace(
  1692. struct ggml_context * ctx,
  1693. struct ggml_tensor * a,
  1694. struct ggml_tensor * b) {
  1695. return ggml_add_impl(ctx, a, b, true);
  1696. }
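// Illustrative sketch (not part of ggml.c): binary ops only record a node in
// the context; nothing is computed here. The graph construction below assumes
// the cgraph helpers declared in ggml.h (ggml_new_graph,
// ggml_build_forward_expand); actual evaluation is performed later by a
// backend, outside this file.
//
//   struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
//   struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
//   struct ggml_tensor * c = ggml_add(ctx, a, b);   // c->op == GGML_OP_ADD
//   struct ggml_cgraph * gf = ggml_new_graph(ctx);
//   ggml_build_forward_expand(gf, c);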
  1697. // ggml_add_cast
  1698. static struct ggml_tensor * ggml_add_cast_impl(
  1699. struct ggml_context * ctx,
  1700. struct ggml_tensor * a,
  1701. struct ggml_tensor * b,
  1702. enum ggml_type type) {
  1703. // TODO: support less-strict constraint
  1704. // GGML_ASSERT(ggml_can_repeat(b, a));
  1705. GGML_ASSERT(ggml_can_repeat_rows(b, a));
1706. // currently only supported for quantized input, f16, or bf16
  1707. GGML_ASSERT(ggml_is_quantized(a->type) ||
  1708. a->type == GGML_TYPE_F16 ||
  1709. a->type == GGML_TYPE_BF16);
  1710. struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
  1711. result->op = GGML_OP_ADD;
  1712. result->src[0] = a;
  1713. result->src[1] = b;
  1714. return result;
  1715. }
  1716. struct ggml_tensor * ggml_add_cast(
  1717. struct ggml_context * ctx,
  1718. struct ggml_tensor * a,
  1719. struct ggml_tensor * b,
  1720. enum ggml_type type) {
  1721. return ggml_add_cast_impl(ctx, a, b, type);
  1722. }
  1723. struct ggml_tensor * ggml_add_id(
  1724. struct ggml_context * ctx,
  1725. struct ggml_tensor * a,
  1726. struct ggml_tensor * b,
  1727. struct ggml_tensor * ids) {
  1728. GGML_ASSERT(a->ne[0] == b->ne[0]);
  1729. GGML_ASSERT(a->ne[1] == ids->ne[0]);
  1730. GGML_ASSERT(a->ne[2] == ids->ne[1]);
  1731. GGML_ASSERT(ids->type == GGML_TYPE_I32);
  1732. struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
  1733. result->op = GGML_OP_ADD_ID;
  1734. result->src[0] = a;
  1735. result->src[1] = b;
  1736. result->src[2] = ids;
  1737. return result;
  1738. }
  1739. // ggml_add1
  1740. static struct ggml_tensor * ggml_add1_impl(
  1741. struct ggml_context * ctx,
  1742. struct ggml_tensor * a,
  1743. struct ggml_tensor * b,
  1744. bool inplace) {
  1745. GGML_ASSERT(ggml_is_scalar(b));
  1746. GGML_ASSERT(ggml_is_padded_1d(a));
  1747. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1748. result->op = GGML_OP_ADD1;
  1749. result->src[0] = a;
  1750. result->src[1] = b;
  1751. return result;
  1752. }
  1753. struct ggml_tensor * ggml_add1(
  1754. struct ggml_context * ctx,
  1755. struct ggml_tensor * a,
  1756. struct ggml_tensor * b) {
  1757. return ggml_add1_impl(ctx, a, b, false);
  1758. }
  1759. struct ggml_tensor * ggml_add1_inplace(
  1760. struct ggml_context * ctx,
  1761. struct ggml_tensor * a,
  1762. struct ggml_tensor * b) {
  1763. return ggml_add1_impl(ctx, a, b, true);
  1764. }
  1765. // ggml_acc
  1766. static struct ggml_tensor * ggml_acc_impl(
  1767. struct ggml_context * ctx,
  1768. struct ggml_tensor * a,
  1769. struct ggml_tensor * b,
  1770. size_t nb1,
  1771. size_t nb2,
  1772. size_t nb3,
  1773. size_t offset,
  1774. bool inplace) {
  1775. GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a));
  1776. GGML_ASSERT(ggml_is_contiguous(a));
  1777. GGML_ASSERT(a->type == GGML_TYPE_F32);
  1778. GGML_ASSERT(b->type == GGML_TYPE_F32);
  1779. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1780. int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
  1781. ggml_set_op_params(result, params, sizeof(params));
  1782. result->op = GGML_OP_ACC;
  1783. result->src[0] = a;
  1784. result->src[1] = b;
  1785. return result;
  1786. }
  1787. struct ggml_tensor * ggml_acc(
  1788. struct ggml_context * ctx,
  1789. struct ggml_tensor * a,
  1790. struct ggml_tensor * b,
  1791. size_t nb1,
  1792. size_t nb2,
  1793. size_t nb3,
  1794. size_t offset) {
  1795. return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
  1796. }
  1797. struct ggml_tensor * ggml_acc_inplace(
  1798. struct ggml_context * ctx,
  1799. struct ggml_tensor * a,
  1800. struct ggml_tensor * b,
  1801. size_t nb1,
  1802. size_t nb2,
  1803. size_t nb3,
  1804. size_t offset) {
  1805. return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
  1806. }
  1807. // ggml_sub
  1808. static struct ggml_tensor * ggml_sub_impl(
  1809. struct ggml_context * ctx,
  1810. struct ggml_tensor * a,
  1811. struct ggml_tensor * b,
  1812. bool inplace) {
  1813. GGML_ASSERT(ggml_can_repeat(b, a));
  1814. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1815. result->op = GGML_OP_SUB;
  1816. result->src[0] = a;
  1817. result->src[1] = b;
  1818. return result;
  1819. }
  1820. struct ggml_tensor * ggml_sub(
  1821. struct ggml_context * ctx,
  1822. struct ggml_tensor * a,
  1823. struct ggml_tensor * b) {
  1824. return ggml_sub_impl(ctx, a, b, false);
  1825. }
  1826. struct ggml_tensor * ggml_sub_inplace(
  1827. struct ggml_context * ctx,
  1828. struct ggml_tensor * a,
  1829. struct ggml_tensor * b) {
  1830. return ggml_sub_impl(ctx, a, b, true);
  1831. }
  1832. // ggml_mul
  1833. static struct ggml_tensor * ggml_mul_impl(
  1834. struct ggml_context * ctx,
  1835. struct ggml_tensor * a,
  1836. struct ggml_tensor * b,
  1837. bool inplace) {
  1838. GGML_ASSERT(ggml_can_repeat(b, a));
  1839. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1840. result->op = GGML_OP_MUL;
  1841. result->src[0] = a;
  1842. result->src[1] = b;
  1843. return result;
  1844. }
  1845. struct ggml_tensor * ggml_mul(
  1846. struct ggml_context * ctx,
  1847. struct ggml_tensor * a,
  1848. struct ggml_tensor * b) {
  1849. return ggml_mul_impl(ctx, a, b, false);
  1850. }
  1851. struct ggml_tensor * ggml_mul_inplace(
  1852. struct ggml_context * ctx,
  1853. struct ggml_tensor * a,
  1854. struct ggml_tensor * b) {
  1855. return ggml_mul_impl(ctx, a, b, true);
  1856. }
  1857. // ggml_div
  1858. static struct ggml_tensor * ggml_div_impl(
  1859. struct ggml_context * ctx,
  1860. struct ggml_tensor * a,
  1861. struct ggml_tensor * b,
  1862. bool inplace) {
  1863. GGML_ASSERT(ggml_can_repeat(b, a));
  1864. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1865. result->op = GGML_OP_DIV;
  1866. result->src[0] = a;
  1867. result->src[1] = b;
  1868. return result;
  1869. }
  1870. struct ggml_tensor * ggml_div(
  1871. struct ggml_context * ctx,
  1872. struct ggml_tensor * a,
  1873. struct ggml_tensor * b) {
  1874. return ggml_div_impl(ctx, a, b, false);
  1875. }
  1876. struct ggml_tensor * ggml_div_inplace(
  1877. struct ggml_context * ctx,
  1878. struct ggml_tensor * a,
  1879. struct ggml_tensor * b) {
  1880. return ggml_div_impl(ctx, a, b, true);
  1881. }
  1882. // ggml_sqr
  1883. static struct ggml_tensor * ggml_sqr_impl(
  1884. struct ggml_context * ctx,
  1885. struct ggml_tensor * a,
  1886. bool inplace) {
  1887. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1888. result->op = GGML_OP_SQR;
  1889. result->src[0] = a;
  1890. return result;
  1891. }
  1892. struct ggml_tensor * ggml_sqr(
  1893. struct ggml_context * ctx,
  1894. struct ggml_tensor * a) {
  1895. return ggml_sqr_impl(ctx, a, false);
  1896. }
  1897. struct ggml_tensor * ggml_sqr_inplace(
  1898. struct ggml_context * ctx,
  1899. struct ggml_tensor * a) {
  1900. return ggml_sqr_impl(ctx, a, true);
  1901. }
  1902. // ggml_sqrt
  1903. static struct ggml_tensor * ggml_sqrt_impl(
  1904. struct ggml_context * ctx,
  1905. struct ggml_tensor * a,
  1906. bool inplace) {
  1907. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1908. result->op = GGML_OP_SQRT;
  1909. result->src[0] = a;
  1910. return result;
  1911. }
  1912. struct ggml_tensor * ggml_sqrt(
  1913. struct ggml_context * ctx,
  1914. struct ggml_tensor * a) {
  1915. return ggml_sqrt_impl(ctx, a, false);
  1916. }
  1917. struct ggml_tensor * ggml_sqrt_inplace(
  1918. struct ggml_context * ctx,
  1919. struct ggml_tensor * a) {
  1920. return ggml_sqrt_impl(ctx, a, true);
  1921. }
  1922. // ggml_log
  1923. static struct ggml_tensor * ggml_log_impl(
  1924. struct ggml_context * ctx,
  1925. struct ggml_tensor * a,
  1926. bool inplace) {
  1927. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1928. result->op = GGML_OP_LOG;
  1929. result->src[0] = a;
  1930. return result;
  1931. }
  1932. struct ggml_tensor * ggml_log(
  1933. struct ggml_context * ctx,
  1934. struct ggml_tensor * a) {
  1935. return ggml_log_impl(ctx, a, false);
  1936. }
  1937. struct ggml_tensor * ggml_log_inplace(
  1938. struct ggml_context * ctx,
  1939. struct ggml_tensor * a) {
  1940. return ggml_log_impl(ctx, a, true);
  1941. }
  1942. struct ggml_tensor * ggml_expm1(
  1943. struct ggml_context * ctx,
  1944. struct ggml_tensor * a) {
  1945. return ggml_unary(ctx, a, GGML_UNARY_OP_EXPM1);
  1946. }
  1947. struct ggml_tensor * ggml_expm1_inplace(
  1948. struct ggml_context * ctx,
  1949. struct ggml_tensor * a) {
  1950. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXPM1);
  1951. }
  1952. struct ggml_tensor * ggml_softplus(
  1953. struct ggml_context * ctx,
  1954. struct ggml_tensor * a) {
  1955. return ggml_unary(ctx, a, GGML_UNARY_OP_SOFTPLUS);
  1956. }
  1957. struct ggml_tensor * ggml_softplus_inplace(
  1958. struct ggml_context * ctx,
  1959. struct ggml_tensor * a) {
  1960. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SOFTPLUS);
  1961. }
  1962. // ggml_sin
  1963. static struct ggml_tensor * ggml_sin_impl(
  1964. struct ggml_context * ctx,
  1965. struct ggml_tensor * a,
  1966. bool inplace) {
  1967. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1968. result->op = GGML_OP_SIN;
  1969. result->src[0] = a;
  1970. return result;
  1971. }
  1972. struct ggml_tensor * ggml_sin(
  1973. struct ggml_context * ctx,
  1974. struct ggml_tensor * a) {
  1975. return ggml_sin_impl(ctx, a, false);
  1976. }
  1977. struct ggml_tensor * ggml_sin_inplace(
  1978. struct ggml_context * ctx,
  1979. struct ggml_tensor * a) {
  1980. return ggml_sin_impl(ctx, a, true);
  1981. }
  1982. // ggml_cos
  1983. static struct ggml_tensor * ggml_cos_impl(
  1984. struct ggml_context * ctx,
  1985. struct ggml_tensor * a,
  1986. bool inplace) {
  1987. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  1988. result->op = GGML_OP_COS;
  1989. result->src[0] = a;
  1990. return result;
  1991. }
  1992. struct ggml_tensor * ggml_cos(
  1993. struct ggml_context * ctx,
  1994. struct ggml_tensor * a) {
  1995. return ggml_cos_impl(ctx, a, false);
  1996. }
  1997. struct ggml_tensor * ggml_cos_inplace(
  1998. struct ggml_context * ctx,
  1999. struct ggml_tensor * a) {
  2000. return ggml_cos_impl(ctx, a, true);
  2001. }
  2002. // ggml_sum
  2003. struct ggml_tensor * ggml_sum(
  2004. struct ggml_context * ctx,
  2005. struct ggml_tensor * a) {
  2006. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
  2007. result->op = GGML_OP_SUM;
  2008. result->src[0] = a;
  2009. return result;
  2010. }
  2011. // ggml_sum_rows
  2012. struct ggml_tensor * ggml_sum_rows(
  2013. struct ggml_context * ctx,
  2014. struct ggml_tensor * a) {
  2015. int64_t ne[GGML_MAX_DIMS] = { 1 };
  2016. for (int i = 1; i < GGML_MAX_DIMS; ++i) {
  2017. ne[i] = a->ne[i];
  2018. }
  2019. struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
  2020. result->op = GGML_OP_SUM_ROWS;
  2021. result->src[0] = a;
  2022. return result;
  2023. }
  2024. // ggml_cumsum
  2025. struct ggml_tensor * ggml_cumsum(
  2026. struct ggml_context * ctx,
  2027. struct ggml_tensor * a) {
  2028. GGML_ASSERT(a->type == GGML_TYPE_F32);
  2029. struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
  2030. result->op = GGML_OP_CUMSUM;
  2031. result->src[0] = a;
  2032. return result;
  2033. }
  2034. // ggml_mean
  2035. struct ggml_tensor * ggml_mean(
  2036. struct ggml_context * ctx,
  2037. struct ggml_tensor * a) {
  2038. int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] };
  2039. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  2040. result->op = GGML_OP_MEAN;
  2041. result->src[0] = a;
  2042. return result;
  2043. }
  2044. // ggml_argmax
  2045. struct ggml_tensor * ggml_argmax(
  2046. struct ggml_context * ctx,
  2047. struct ggml_tensor * a) {
  2048. GGML_ASSERT(ggml_is_matrix(a));
  2049. GGML_ASSERT(a->ne[0] <= INT32_MAX);
  2050. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);
  2051. result->op = GGML_OP_ARGMAX;
  2052. result->src[0] = a;
  2053. return result;
  2054. }
  2055. // ggml_count_equal
  2056. struct ggml_tensor * ggml_count_equal(
  2057. struct ggml_context * ctx,
  2058. struct ggml_tensor * a,
  2059. struct ggml_tensor * b) {
  2060. GGML_ASSERT(ggml_are_same_shape(a, b));
  2061. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 1);
  2062. result->op = GGML_OP_COUNT_EQUAL;
  2063. result->src[0] = a;
  2064. result->src[1] = b;
  2065. return result;
  2066. }
  2067. // ggml_repeat
  2068. struct ggml_tensor * ggml_repeat(
  2069. struct ggml_context * ctx,
  2070. struct ggml_tensor * a,
  2071. struct ggml_tensor * b) {
  2072. GGML_ASSERT(ggml_can_repeat(a, b));
  2073. struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);
  2074. result->op = GGML_OP_REPEAT;
  2075. result->src[0] = a;
  2076. return result;
  2077. }
  2078. struct ggml_tensor * ggml_repeat_4d(
  2079. struct ggml_context * ctx,
  2080. struct ggml_tensor * a,
  2081. int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
  2082. const bool can_repeat = ggml_is_empty(a) || (
  2083. (ne0 % a->ne[0] == 0) &&
  2084. (ne1 % a->ne[1] == 0) &&
  2085. (ne2 % a->ne[2] == 0) &&
  2086. (ne3 % a->ne[3] == 0)
  2087. );
  2088. GGML_ASSERT(can_repeat);
  2089. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
  2090. result->op = GGML_OP_REPEAT;
  2091. result->src[0] = a;
  2092. return result;
  2093. }
  2094. // ggml_repeat_back
  2095. struct ggml_tensor * ggml_repeat_back(
  2096. struct ggml_context * ctx,
  2097. struct ggml_tensor * a,
  2098. struct ggml_tensor * b) {
  2099. GGML_ASSERT(ggml_can_repeat(b, a));
  2100. struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);
  2101. result->op = GGML_OP_REPEAT_BACK;
  2102. result->src[0] = a;
  2103. return result;
  2104. }
  2105. // ggml_concat
  2106. struct ggml_tensor * ggml_concat(
  2107. struct ggml_context * ctx,
  2108. struct ggml_tensor * a,
  2109. struct ggml_tensor * b,
  2110. int dim) {
  2111. GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
  2112. GGML_ASSERT(a->type == b->type);
  2113. int64_t ne[GGML_MAX_DIMS];
  2114. for (int d = 0; d < GGML_MAX_DIMS; ++d) {
  2115. if (d == dim) {
  2116. ne[d] = a->ne[d] + b->ne[d];
  2117. continue;
  2118. }
  2119. GGML_ASSERT(a->ne[d] == b->ne[d]);
  2120. ne[d] = a->ne[d];
  2121. }
  2122. struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
  2123. ggml_set_op_params_i32(result, 0, dim);
  2124. result->op = GGML_OP_CONCAT;
  2125. result->src[0] = a;
  2126. result->src[1] = b;
  2127. return result;
  2128. }
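// Example (comment only): concatenating a with ne = {2, 3, 1, 1} and b with
// ne = {2, 5, 1, 1} along dim = 1 yields a result with ne = {2, 8, 1, 1};
// all dimensions other than `dim` must match, and both inputs must share the
// same type.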
  2129. // ggml_abs
  2130. struct ggml_tensor * ggml_abs(
  2131. struct ggml_context * ctx,
  2132. struct ggml_tensor * a) {
  2133. return ggml_unary(ctx, a, GGML_UNARY_OP_ABS);
  2134. }
  2135. struct ggml_tensor * ggml_abs_inplace(
  2136. struct ggml_context * ctx,
  2137. struct ggml_tensor * a) {
  2138. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ABS);
  2139. }
  2140. // ggml_sgn
  2141. struct ggml_tensor * ggml_sgn(
  2142. struct ggml_context * ctx,
  2143. struct ggml_tensor * a) {
  2144. return ggml_unary(ctx, a, GGML_UNARY_OP_SGN);
  2145. }
  2146. struct ggml_tensor * ggml_sgn_inplace(
  2147. struct ggml_context * ctx,
  2148. struct ggml_tensor * a) {
  2149. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SGN);
  2150. }
  2151. // ggml_neg
  2152. struct ggml_tensor * ggml_neg(
  2153. struct ggml_context * ctx,
  2154. struct ggml_tensor * a) {
  2155. return ggml_unary(ctx, a, GGML_UNARY_OP_NEG);
  2156. }
  2157. struct ggml_tensor * ggml_neg_inplace(
  2158. struct ggml_context * ctx,
  2159. struct ggml_tensor * a) {
  2160. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_NEG);
  2161. }
  2162. // ggml_step
  2163. struct ggml_tensor * ggml_step(
  2164. struct ggml_context * ctx,
  2165. struct ggml_tensor * a) {
  2166. return ggml_unary(ctx, a, GGML_UNARY_OP_STEP);
  2167. }
  2168. struct ggml_tensor * ggml_step_inplace(
  2169. struct ggml_context * ctx,
  2170. struct ggml_tensor * a) {
  2171. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_STEP);
  2172. }
  2173. // ggml_tanh
  2174. struct ggml_tensor * ggml_tanh(
  2175. struct ggml_context * ctx,
  2176. struct ggml_tensor * a) {
  2177. return ggml_unary(ctx, a, GGML_UNARY_OP_TANH);
  2178. }
  2179. struct ggml_tensor * ggml_tanh_inplace(
  2180. struct ggml_context * ctx,
  2181. struct ggml_tensor * a) {
  2182. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TANH);
  2183. }
  2184. // ggml_elu
  2185. struct ggml_tensor * ggml_elu(
  2186. struct ggml_context * ctx,
  2187. struct ggml_tensor * a) {
  2188. return ggml_unary(ctx, a, GGML_UNARY_OP_ELU);
  2189. }
  2190. struct ggml_tensor * ggml_elu_inplace(
  2191. struct ggml_context * ctx,
  2192. struct ggml_tensor * a) {
  2193. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ELU);
  2194. }
  2195. // ggml_relu
  2196. struct ggml_tensor * ggml_relu(
  2197. struct ggml_context * ctx,
  2198. struct ggml_tensor * a) {
  2199. return ggml_unary(ctx, a, GGML_UNARY_OP_RELU);
  2200. }
  2201. struct ggml_tensor * ggml_relu_inplace(
  2202. struct ggml_context * ctx,
  2203. struct ggml_tensor * a) {
  2204. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
  2205. }
  2206. // ggml_leaky_relu
  2207. struct ggml_tensor * ggml_leaky_relu(
  2208. struct ggml_context * ctx,
  2209. struct ggml_tensor * a,
  2210. float negative_slope,
  2211. bool inplace) {
  2212. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  2213. ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
  2214. result->op = GGML_OP_LEAKY_RELU;
  2215. result->src[0] = a;
  2216. return result;
  2217. }
  2218. // ggml_sigmoid
  2219. struct ggml_tensor * ggml_sigmoid(
  2220. struct ggml_context * ctx,
  2221. struct ggml_tensor * a) {
  2222. return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);
  2223. }
  2224. struct ggml_tensor * ggml_sigmoid_inplace(
  2225. struct ggml_context * ctx,
  2226. struct ggml_tensor * a) {
  2227. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
  2228. }
  2229. // ggml_gelu
  2230. struct ggml_tensor * ggml_gelu(
  2231. struct ggml_context * ctx,
  2232. struct ggml_tensor * a) {
  2233. return ggml_unary(ctx, a, GGML_UNARY_OP_GELU);
  2234. }
  2235. struct ggml_tensor * ggml_gelu_inplace(
  2236. struct ggml_context * ctx,
  2237. struct ggml_tensor * a) {
  2238. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
  2239. }
  2240. // ggml_gelu_erf
  2241. struct ggml_tensor * ggml_gelu_erf(
  2242. struct ggml_context * ctx,
  2243. struct ggml_tensor * a) {
  2244. return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_ERF);
  2245. }
  2246. struct ggml_tensor * ggml_gelu_erf_inplace(
  2247. struct ggml_context * ctx,
  2248. struct ggml_tensor * a) {
  2249. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_ERF);
  2250. }
  2251. // ggml_gelu_quick
  2252. struct ggml_tensor * ggml_gelu_quick(
  2253. struct ggml_context * ctx,
  2254. struct ggml_tensor * a) {
  2255. return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_QUICK);
  2256. }
  2257. struct ggml_tensor * ggml_gelu_quick_inplace(
  2258. struct ggml_context * ctx,
  2259. struct ggml_tensor * a) {
  2260. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_QUICK);
  2261. }
  2262. // ggml_silu
  2263. struct ggml_tensor * ggml_silu(
  2264. struct ggml_context * ctx,
  2265. struct ggml_tensor * a) {
  2266. return ggml_unary(ctx, a, GGML_UNARY_OP_SILU);
  2267. }
  2268. struct ggml_tensor * ggml_silu_inplace(
  2269. struct ggml_context * ctx,
  2270. struct ggml_tensor * a) {
  2271. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
  2272. }
  2273. // ggml_xielu
  2274. struct ggml_tensor * ggml_xielu(
  2275. struct ggml_context * ctx,
  2276. struct ggml_tensor * a,
  2277. float alpha_n,
  2278. float alpha_p,
  2279. float beta,
  2280. float eps) {
  2281. struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
  2282. ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU);
  2283. ggml_set_op_params_f32(result, 1, beta + ggml_compute_softplus_f32(alpha_n));
  2284. ggml_set_op_params_f32(result, 2, ggml_compute_softplus_f32(alpha_p));
  2285. ggml_set_op_params_f32(result, 3, beta);
  2286. ggml_set_op_params_f32(result, 4, eps);
  2287. result->op = GGML_OP_UNARY;
  2288. result->src[0] = a;
  2289. return result;
  2290. }
  2291. // ggml_silu_back
  2292. struct ggml_tensor * ggml_silu_back(
  2293. struct ggml_context * ctx,
  2294. struct ggml_tensor * a,
  2295. struct ggml_tensor * b) {
  2296. struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
  2297. result->op = GGML_OP_SILU_BACK;
  2298. result->src[0] = a;
  2299. result->src[1] = b;
  2300. return result;
  2301. }
2302. // ggml_hardswish
  2303. struct ggml_tensor * ggml_hardswish(
  2304. struct ggml_context * ctx,
  2305. struct ggml_tensor * a) {
  2306. return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSWISH);
  2307. }
2308. // ggml_hardsigmoid
  2309. struct ggml_tensor * ggml_hardsigmoid(
  2310. struct ggml_context * ctx,
  2311. struct ggml_tensor * a) {
  2312. return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);
  2313. }
2314. // ggml_exp
  2315. struct ggml_tensor * ggml_exp(
  2316. struct ggml_context * ctx,
  2317. struct ggml_tensor * a) {
  2318. return ggml_unary(ctx, a, GGML_UNARY_OP_EXP);
  2319. }
  2320. struct ggml_tensor * ggml_exp_inplace(
  2321. struct ggml_context * ctx,
  2322. struct ggml_tensor * a) {
  2323. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
  2324. }
  2325. // ggml_glu
  2326. static struct ggml_tensor * ggml_glu_impl(
  2327. struct ggml_context * ctx,
  2328. struct ggml_tensor * a,
  2329. struct ggml_tensor * b,
  2330. enum ggml_glu_op op,
  2331. bool swapped) {
  2332. GGML_ASSERT(ggml_is_contiguous_1(a));
  2333. if (b) {
  2334. GGML_ASSERT(ggml_is_contiguous_1(b));
  2335. GGML_ASSERT(ggml_are_same_shape(a, b));
  2336. GGML_ASSERT(a->type == b->type);
  2337. }
  2338. int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
  2339. struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
  2340. ggml_set_op_params_i32(result, 0, (int32_t) op);
  2341. ggml_set_op_params_i32(result, 1, (int32_t) swapped);
  2342. result->op = GGML_OP_GLU;
  2343. result->src[0] = a;
  2344. result->src[1] = b;
  2345. return result;
  2346. }
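// Comment only: with b == NULL the input `a` carries both halves and the
// result keeps only a->ne[0]/2 columns (the other half acts as the gate;
// `swapped` selects which half that is), while the *_split variants further
// below pass the gate as a separate tensor `b` of the same shape, so the
// result keeps a's full width.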
  2347. // ggml_floor
  2348. struct ggml_tensor * ggml_floor(
  2349. struct ggml_context * ctx,
  2350. struct ggml_tensor * a) {
  2351. return ggml_unary(ctx, a, GGML_UNARY_OP_FLOOR);
  2352. }
  2353. struct ggml_tensor * ggml_floor_inplace(
  2354. struct ggml_context * ctx,
  2355. struct ggml_tensor * a) {
  2356. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_FLOOR);
  2357. }
  2358. // ggml_ceil
  2359. struct ggml_tensor * ggml_ceil(
  2360. struct ggml_context * ctx,
  2361. struct ggml_tensor * a) {
  2362. return ggml_unary(ctx, a, GGML_UNARY_OP_CEIL);
  2363. }
  2364. struct ggml_tensor * ggml_ceil_inplace(
  2365. struct ggml_context * ctx,
  2366. struct ggml_tensor * a) {
  2367. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_CEIL);
  2368. }
2369. // ggml_round
  2370. struct ggml_tensor * ggml_round(
  2371. struct ggml_context * ctx,
  2372. struct ggml_tensor * a) {
  2373. return ggml_unary(ctx, a, GGML_UNARY_OP_ROUND);
  2374. }
  2375. struct ggml_tensor * ggml_round_inplace(
  2376. struct ggml_context * ctx,
  2377. struct ggml_tensor * a) {
  2378. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ROUND);
  2379. }
2380. // ggml_trunc
  2381. struct ggml_tensor * ggml_trunc(
  2382. struct ggml_context * ctx,
  2383. struct ggml_tensor * a) {
  2384. return ggml_unary(ctx, a, GGML_UNARY_OP_TRUNC);
  2385. }
  2386. struct ggml_tensor * ggml_trunc_inplace(
  2387. struct ggml_context * ctx,
  2388. struct ggml_tensor * a) {
  2389. return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TRUNC);
  2390. }
  2391. struct ggml_tensor * ggml_glu(
  2392. struct ggml_context * ctx,
  2393. struct ggml_tensor * a,
  2394. enum ggml_glu_op op,
  2395. bool swapped) {
  2396. return ggml_glu_impl(ctx, a, NULL, op, swapped);
  2397. }
  2398. struct ggml_tensor * ggml_glu_split(
  2399. struct ggml_context * ctx,
  2400. struct ggml_tensor * a,
  2401. struct ggml_tensor * b,
  2402. enum ggml_glu_op op) {
  2403. return ggml_glu_impl(ctx, a, b, op, false);
  2404. }
  2405. // ggml_reglu
  2406. struct ggml_tensor * ggml_reglu(
  2407. struct ggml_context * ctx,
  2408. struct ggml_tensor * a) {
  2409. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false);
  2410. }
  2411. struct ggml_tensor * ggml_reglu_swapped(
  2412. struct ggml_context * ctx,
  2413. struct ggml_tensor * a) {
  2414. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true);
  2415. }
  2416. struct ggml_tensor * ggml_reglu_split(
  2417. struct ggml_context * ctx,
  2418. struct ggml_tensor * a,
  2419. struct ggml_tensor * b) {
  2420. return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false);
  2421. }
  2422. // ggml_geglu
  2423. struct ggml_tensor * ggml_geglu(
  2424. struct ggml_context * ctx,
  2425. struct ggml_tensor * a) {
  2426. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false);
  2427. }
  2428. struct ggml_tensor * ggml_geglu_swapped(
  2429. struct ggml_context * ctx,
  2430. struct ggml_tensor * a) {
  2431. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true);
  2432. }
  2433. struct ggml_tensor * ggml_geglu_split(
  2434. struct ggml_context * ctx,
  2435. struct ggml_tensor * a,
  2436. struct ggml_tensor * b) {
  2437. return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false);
  2438. }
  2439. // ggml_swiglu
  2440. struct ggml_tensor * ggml_swiglu(
  2441. struct ggml_context * ctx,
  2442. struct ggml_tensor * a) {
  2443. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false);
  2444. }
  2445. struct ggml_tensor * ggml_swiglu_swapped(
  2446. struct ggml_context * ctx,
  2447. struct ggml_tensor * a) {
  2448. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true);
  2449. }
  2450. struct ggml_tensor * ggml_swiglu_split(
  2451. struct ggml_context * ctx,
  2452. struct ggml_tensor * a,
  2453. struct ggml_tensor * b) {
  2454. return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
  2455. }
  2456. // ggml_geglu_erf
  2457. struct ggml_tensor * ggml_geglu_erf(
  2458. struct ggml_context * ctx,
  2459. struct ggml_tensor * a) {
  2460. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false);
  2461. }
  2462. struct ggml_tensor * ggml_geglu_erf_swapped(
  2463. struct ggml_context * ctx,
  2464. struct ggml_tensor * a) {
  2465. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true);
  2466. }
  2467. struct ggml_tensor * ggml_geglu_erf_split(
  2468. struct ggml_context * ctx,
  2469. struct ggml_tensor * a,
  2470. struct ggml_tensor * b) {
  2471. return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false);
  2472. }
  2473. // ggml_geglu_quick
  2474. struct ggml_tensor * ggml_geglu_quick(
  2475. struct ggml_context * ctx,
  2476. struct ggml_tensor * a) {
  2477. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false);
  2478. }
  2479. struct ggml_tensor * ggml_geglu_quick_swapped(
  2480. struct ggml_context * ctx,
  2481. struct ggml_tensor * a) {
  2482. return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true);
  2483. }
  2484. struct ggml_tensor * ggml_geglu_quick_split(
  2485. struct ggml_context * ctx,
  2486. struct ggml_tensor * a,
  2487. struct ggml_tensor * b) {
  2488. return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
  2489. }
  2490. struct ggml_tensor * ggml_swiglu_oai(
  2491. struct ggml_context * ctx,
  2492. struct ggml_tensor * a,
  2493. struct ggml_tensor * b,
  2494. float alpha,
  2495. float limit) {
  2496. struct ggml_tensor * result = ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU_OAI, false);
  2497. ggml_set_op_params_f32(result, 2, alpha);
  2498. ggml_set_op_params_f32(result, 3, limit);
  2499. return result;
  2500. }
  2501. // ggml_norm
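// ggml_norm applies standard layer normalization along ne[0]: each row is shifted to zero mean
// and scaled to unit variance, with eps added to the variance for numerical stability. Any
// learned scale/bias is applied separately via ggml_mul/ggml_add.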
  2502. static struct ggml_tensor * ggml_norm_impl(
  2503. struct ggml_context * ctx,
  2504. struct ggml_tensor * a,
  2505. float eps,
  2506. bool inplace) {
  2507. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  2508. ggml_set_op_params(result, &eps, sizeof(eps));
  2509. result->op = GGML_OP_NORM;
  2510. result->src[0] = a;
  2511. return result;
  2512. }
  2513. struct ggml_tensor * ggml_norm(
  2514. struct ggml_context * ctx,
  2515. struct ggml_tensor * a,
  2516. float eps) {
  2517. return ggml_norm_impl(ctx, a, eps, false);
  2518. }
  2519. struct ggml_tensor * ggml_norm_inplace(
  2520. struct ggml_context * ctx,
  2521. struct ggml_tensor * a,
  2522. float eps) {
  2523. return ggml_norm_impl(ctx, a, eps, true);
  2524. }
  2525. // ggml_rms_norm
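// ggml_rms_norm normalizes along ne[0] by the root mean square of the row,
// i.e. x / sqrt(mean(x^2) + eps), without the mean subtraction done by ggml_norm.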
  2526. static struct ggml_tensor * ggml_rms_norm_impl(
  2527. struct ggml_context * ctx,
  2528. struct ggml_tensor * a,
  2529. float eps,
  2530. bool inplace) {
  2531. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  2532. ggml_set_op_params(result, &eps, sizeof(eps));
  2533. result->op = GGML_OP_RMS_NORM;
  2534. result->src[0] = a;
  2535. return result;
  2536. }
  2537. struct ggml_tensor * ggml_rms_norm(
  2538. struct ggml_context * ctx,
  2539. struct ggml_tensor * a,
  2540. float eps) {
  2541. return ggml_rms_norm_impl(ctx, a, eps, false);
  2542. }
  2543. struct ggml_tensor * ggml_rms_norm_inplace(
  2544. struct ggml_context * ctx,
  2545. struct ggml_tensor * a,
  2546. float eps) {
  2547. return ggml_rms_norm_impl(ctx, a, eps, true);
  2548. }
  2549. // ggml_rms_norm_back
  2550. struct ggml_tensor * ggml_rms_norm_back(
  2551. struct ggml_context * ctx,
  2552. struct ggml_tensor * a,
  2553. struct ggml_tensor * b,
  2554. float eps) {
  2555. struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
  2556. ggml_set_op_params(result, &eps, sizeof(eps));
  2557. result->op = GGML_OP_RMS_NORM_BACK;
  2558. result->src[0] = a;
  2559. result->src[1] = b;
  2560. return result;
  2561. }
  2562. // ggml_group_norm
  2563. static struct ggml_tensor * ggml_group_norm_impl(
  2564. struct ggml_context * ctx,
  2565. struct ggml_tensor * a,
  2566. int n_groups,
  2567. float eps,
  2568. bool inplace) {
  2569. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  2570. ggml_set_op_params_i32(result, 0, n_groups);
  2571. ggml_set_op_params_f32(result, 1, eps);
  2572. result->op = GGML_OP_GROUP_NORM;
  2573. result->src[0] = a;
  2574. return result;
  2575. }
  2576. struct ggml_tensor * ggml_group_norm(
  2577. struct ggml_context * ctx,
  2578. struct ggml_tensor * a,
  2579. int n_groups,
  2580. float eps) {
  2581. return ggml_group_norm_impl(ctx, a, n_groups, eps, false);
  2582. }
  2583. struct ggml_tensor * ggml_group_norm_inplace(
  2584. struct ggml_context * ctx,
  2585. struct ggml_tensor * a,
  2586. int n_groups,
  2587. float eps) {
  2588. return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
  2589. }
  2590. // ggml_l2_norm
  2591. static struct ggml_tensor * ggml_l2_norm_impl(
  2592. struct ggml_context * ctx,
  2593. struct ggml_tensor * a,
  2594. float eps,
  2595. bool inplace) {
  2596. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  2597. ggml_set_op_params_f32(result, 0, eps);
  2598. result->op = GGML_OP_L2_NORM;
  2599. result->src[0] = a;
  2600. return result;
  2601. }
  2602. struct ggml_tensor * ggml_l2_norm(
  2603. struct ggml_context * ctx,
  2604. struct ggml_tensor * a,
  2605. float eps) {
  2606. return ggml_l2_norm_impl(ctx, a, eps, false);
  2607. }
  2608. struct ggml_tensor * ggml_l2_norm_inplace(
  2609. struct ggml_context * ctx,
  2610. struct ggml_tensor * a,
  2611. float eps) {
  2612. return ggml_l2_norm_impl(ctx, a, eps, true);
  2613. }
  2614. // ggml_mul_mat
  2615. static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  2616. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  2617. return (t0->ne[0] == t1->ne[0]) &&
  2618. (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
  2619. (t1->ne[3]%t0->ne[3] == 0);
  2620. }
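// ggml_mul_mat contracts a and b along their shared ne[0] dimension:
//   a: [K, M, ...], b: [K, N, ...]  ->  result: [M, N, ...]  (always F32, see ne below)
// a is broadcast along ne[2]/ne[3] when b's outer dimensions are a multiple of a's.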
  2621. struct ggml_tensor * ggml_mul_mat(
  2622. struct ggml_context * ctx,
  2623. struct ggml_tensor * a,
  2624. struct ggml_tensor * b) {
  2625. GGML_ASSERT(ggml_can_mul_mat(a, b));
  2626. GGML_ASSERT(!ggml_is_transposed(a));
  2627. const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
  2628. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  2629. result->op = GGML_OP_MUL_MAT;
  2630. result->src[0] = a;
  2631. result->src[1] = b;
  2632. return result;
  2633. }
  2634. void ggml_mul_mat_set_prec(
  2635. struct ggml_tensor * a,
  2636. enum ggml_prec prec) {
  2637. GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
  2638. const int32_t prec_i32 = (int32_t) prec;
  2639. ggml_set_op_params_i32(a, 0, prec_i32);
  2640. }
  2641. // ggml_mul_mat_id
  2642. /*
  2643. c = ggml_mul_mat_id(ctx, as, b, ids);
  2644. as -> [cols, rows, n_expert]
  2645. b -> [cols, n_expert_used, n_tokens]
  2646. ids -> [n_expert_used, n_tokens] (i32)
  2647. c -> [rows, n_expert_used, n_tokens]
  2648. in b, n_expert_used can be broadcasted to match the n_expert_used of ids
  2649. c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
  2650. */
  2651. struct ggml_tensor * ggml_mul_mat_id(
  2652. struct ggml_context * ctx,
  2653. struct ggml_tensor * as,
  2654. struct ggml_tensor * b,
  2655. struct ggml_tensor * ids) {
  2656. GGML_ASSERT(!ggml_is_transposed(as));
  2657. GGML_ASSERT(ids->type == GGML_TYPE_I32);
  2658. GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert)
  2659. GGML_ASSERT(b->ne[3] == 1); // b is 3d
  2660. GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
  2661. GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
  2662. GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
  2663. GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
  2664. const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };
  2665. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  2666. result->op = GGML_OP_MUL_MAT_ID;
  2667. result->src[0] = as;
  2668. result->src[1] = b;
  2669. result->src[2] = ids;
  2670. return result;
  2671. }
  2672. // ggml_out_prod
  2673. static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  2674. static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
  2675. return (t0->ne[1] == t1->ne[1]) &&
  2676. (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
  2677. (t1->ne[3]%t0->ne[3] == 0);
  2678. }
  2679. struct ggml_tensor * ggml_out_prod(
  2680. struct ggml_context * ctx,
  2681. struct ggml_tensor * a,
  2682. struct ggml_tensor * b) {
  2683. GGML_ASSERT(ggml_can_out_prod(a, b));
  2684. GGML_ASSERT(!ggml_is_transposed(a));
  2685. // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
  2686. const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
  2687. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  2688. result->op = GGML_OP_OUT_PROD;
  2689. result->src[0] = a;
  2690. result->src[1] = b;
  2691. return result;
  2692. }
  2693. // ggml_scale
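// ggml_scale_impl computes s*a + b element-wise, where b is a scalar bias; the plain
// ggml_scale/ggml_scale_inplace wrappers below pass b = 0.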
  2694. static struct ggml_tensor * ggml_scale_impl(
  2695. struct ggml_context * ctx,
  2696. struct ggml_tensor * a,
  2697. float s,
  2698. float b,
  2699. bool inplace) {
  2700. GGML_ASSERT(ggml_is_padded_1d(a));
  2701. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  2702. float params[2] = { s, b };
  2703. ggml_set_op_params(result, &params, sizeof(params));
  2704. result->op = GGML_OP_SCALE;
  2705. result->src[0] = a;
  2706. return result;
  2707. }
  2708. struct ggml_tensor * ggml_scale(
  2709. struct ggml_context * ctx,
  2710. struct ggml_tensor * a,
  2711. float s) {
  2712. return ggml_scale_impl(ctx, a, s, 0.0, false);
  2713. }
  2714. struct ggml_tensor * ggml_scale_inplace(
  2715. struct ggml_context * ctx,
  2716. struct ggml_tensor * a,
  2717. float s) {
  2718. return ggml_scale_impl(ctx, a, s, 0.0, true);
  2719. }
  2720. struct ggml_tensor * ggml_scale_bias(
  2721. struct ggml_context * ctx,
  2722. struct ggml_tensor * a,
  2723. float s,
  2724. float b) {
  2725. return ggml_scale_impl(ctx, a, s, b, false);
  2726. }
  2727. struct ggml_tensor * ggml_scale_bias_inplace(
  2728. struct ggml_context * ctx,
  2729. struct ggml_tensor * a,
  2730. float s,
  2731. float b) {
  2732. return ggml_scale_impl(ctx, a, s, b, true);
  2733. }
  2734. // ggml_set
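// ggml_set_impl overwrites the region of a described by the byte strides nb1..nb3 and the byte
// offset with the elements of b; the non-inplace variants return a modified copy of a, while
// the inplace variants return a view of a.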
  2735. static struct ggml_tensor * ggml_set_impl(
  2736. struct ggml_context * ctx,
  2737. struct ggml_tensor * a,
  2738. struct ggml_tensor * b,
  2739. size_t nb1,
  2740. size_t nb2,
  2741. size_t nb3,
  2742. size_t offset,
  2743. bool inplace) {
  2744. GGML_ASSERT(ggml_nelements(a) >= ggml_nelements(b));
  2745. // make a view of the destination
  2746. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  2747. GGML_ASSERT(offset < (size_t)(1 << 30));
  2748. int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
  2749. ggml_set_op_params(result, params, sizeof(params));
  2750. result->op = GGML_OP_SET;
  2751. result->src[0] = a;
  2752. result->src[1] = b;
  2753. return result;
  2754. }
  2755. struct ggml_tensor * ggml_set(
  2756. struct ggml_context * ctx,
  2757. struct ggml_tensor * a,
  2758. struct ggml_tensor * b,
  2759. size_t nb1,
  2760. size_t nb2,
  2761. size_t nb3,
  2762. size_t offset) {
  2763. return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
  2764. }
  2765. struct ggml_tensor * ggml_set_inplace(
  2766. struct ggml_context * ctx,
  2767. struct ggml_tensor * a,
  2768. struct ggml_tensor * b,
  2769. size_t nb1,
  2770. size_t nb2,
  2771. size_t nb3,
  2772. size_t offset) {
  2773. return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
  2774. }
  2775. struct ggml_tensor * ggml_set_1d(
  2776. struct ggml_context * ctx,
  2777. struct ggml_tensor * a,
  2778. struct ggml_tensor * b,
  2779. size_t offset) {
  2780. return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false);
  2781. }
  2782. struct ggml_tensor * ggml_set_1d_inplace(
  2783. struct ggml_context * ctx,
  2784. struct ggml_tensor * a,
  2785. struct ggml_tensor * b,
  2786. size_t offset) {
  2787. return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true);
  2788. }
  2789. struct ggml_tensor * ggml_set_2d(
  2790. struct ggml_context * ctx,
  2791. struct ggml_tensor * a,
  2792. struct ggml_tensor * b,
  2793. size_t nb1,
  2794. size_t offset) {
  2795. return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
  2796. }
  2797. struct ggml_tensor * ggml_set_2d_inplace(
  2798. struct ggml_context * ctx,
  2799. struct ggml_tensor * a,
  2800. struct ggml_tensor * b,
  2801. size_t nb1,
  2802. size_t offset) {
  2803. return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);
  2804. }
  2805. // ggml_cpy
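// ggml_cpy copies the data of a into b (converting types if needed) and returns a view of b;
// ggml_cast below does the same into a freshly allocated tensor of the requested type.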
  2806. static struct ggml_tensor * ggml_cpy_impl(
  2807. struct ggml_context * ctx,
  2808. struct ggml_tensor * a,
  2809. struct ggml_tensor * b) {
  2810. GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
  2811. // make a view of the destination
  2812. struct ggml_tensor * result = ggml_view_tensor(ctx, b);
  2813. if (strlen(b->name) > 0) {
  2814. ggml_format_name(result, "%s (copy of %s)", b->name, a->name);
  2815. } else {
  2816. ggml_format_name(result, "%s (copy)", a->name);
  2817. }
  2818. result->op = GGML_OP_CPY;
  2819. result->src[0] = a;
  2820. result->src[1] = b;
  2821. return result;
  2822. }
  2823. struct ggml_tensor * ggml_cpy(
  2824. struct ggml_context * ctx,
  2825. struct ggml_tensor * a,
  2826. struct ggml_tensor * b) {
  2827. return ggml_cpy_impl(ctx, a, b);
  2828. }
  2829. struct ggml_tensor * ggml_cast(
  2830. struct ggml_context * ctx,
  2831. struct ggml_tensor * a,
  2832. enum ggml_type type) {
  2833. struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
  2834. ggml_format_name(result, "%s (copy)", a->name);
  2835. result->op = GGML_OP_CPY;
  2836. result->src[0] = a;
  2837. result->src[1] = result;
  2838. return result;
  2839. }
  2840. // ggml_cont
  2841. static struct ggml_tensor * ggml_cont_impl(
  2842. struct ggml_context * ctx,
  2843. struct ggml_tensor * a) {
  2844. struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
  2845. ggml_format_name(result, "%s (cont)", a->name);
  2846. result->op = GGML_OP_CONT;
  2847. result->src[0] = a;
  2848. return result;
  2849. }
  2850. struct ggml_tensor * ggml_cont(
  2851. struct ggml_context * ctx,
  2852. struct ggml_tensor * a) {
  2853. return ggml_cont_impl(ctx, a);
  2854. }
  2855. // make contiguous, with new shape
  2856. GGML_API struct ggml_tensor * ggml_cont_1d(
  2857. struct ggml_context * ctx,
  2858. struct ggml_tensor * a,
  2859. int64_t ne0) {
  2860. return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
  2861. }
  2862. GGML_API struct ggml_tensor * ggml_cont_2d(
  2863. struct ggml_context * ctx,
  2864. struct ggml_tensor * a,
  2865. int64_t ne0,
  2866. int64_t ne1) {
  2867. return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
  2868. }
  2869. GGML_API struct ggml_tensor * ggml_cont_3d(
  2870. struct ggml_context * ctx,
  2871. struct ggml_tensor * a,
  2872. int64_t ne0,
  2873. int64_t ne1,
  2874. int64_t ne2) {
  2875. return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
  2876. }
  2877. struct ggml_tensor * ggml_cont_4d(
  2878. struct ggml_context * ctx,
  2879. struct ggml_tensor * a,
  2880. int64_t ne0,
  2881. int64_t ne1,
  2882. int64_t ne2,
  2883. int64_t ne3) {
  2884. GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
  2885. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
  2886. ggml_format_name(result, "%s (cont)", a->name);
  2887. result->op = GGML_OP_CONT;
  2888. result->src[0] = a;
  2889. return result;
  2890. }
  2891. // ggml_reshape
  2892. struct ggml_tensor * ggml_reshape(
  2893. struct ggml_context * ctx,
  2894. struct ggml_tensor * a,
  2895. struct ggml_tensor * b) {
  2896. GGML_ASSERT(ggml_is_contiguous(a));
2897. // as only the shape of b is relevant, and not its memory layout, b is allowed to be non-contiguous.
  2898. GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
  2899. struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b->ne, a, 0);
  2900. ggml_format_name(result, "%s (reshaped)", a->name);
  2901. result->op = GGML_OP_RESHAPE;
  2902. result->src[0] = a;
  2903. return result;
  2904. }
  2905. struct ggml_tensor * ggml_reshape_1d(
  2906. struct ggml_context * ctx,
  2907. struct ggml_tensor * a,
  2908. int64_t ne0) {
  2909. GGML_ASSERT(ggml_is_contiguous(a));
  2910. GGML_ASSERT(ggml_nelements(a) == ne0);
  2911. const int64_t ne[1] = { ne0 };
  2912. struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0);
  2913. ggml_format_name(result, "%s (reshaped)", a->name);
  2914. result->op = GGML_OP_RESHAPE;
  2915. result->src[0] = a;
  2916. return result;
  2917. }
  2918. struct ggml_tensor * ggml_reshape_2d(
  2919. struct ggml_context * ctx,
  2920. struct ggml_tensor * a,
  2921. int64_t ne0,
  2922. int64_t ne1) {
  2923. GGML_ASSERT(ggml_is_contiguous(a));
  2924. GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
  2925. const int64_t ne[2] = { ne0, ne1 };
  2926. struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0);
  2927. ggml_format_name(result, "%s (reshaped)", a->name);
  2928. result->op = GGML_OP_RESHAPE;
  2929. result->src[0] = a;
  2930. return result;
  2931. }
  2932. struct ggml_tensor * ggml_reshape_3d(
  2933. struct ggml_context * ctx,
  2934. struct ggml_tensor * a,
  2935. int64_t ne0,
  2936. int64_t ne1,
  2937. int64_t ne2) {
  2938. GGML_ASSERT(ggml_is_contiguous(a));
  2939. GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);
  2940. const int64_t ne[3] = { ne0, ne1, ne2 };
  2941. struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0);
  2942. ggml_format_name(result, "%s (reshaped)", a->name);
  2943. result->op = GGML_OP_RESHAPE;
  2944. result->src[0] = a;
  2945. return result;
  2946. }
  2947. struct ggml_tensor * ggml_reshape_4d(
  2948. struct ggml_context * ctx,
  2949. struct ggml_tensor * a,
  2950. int64_t ne0,
  2951. int64_t ne1,
  2952. int64_t ne2,
  2953. int64_t ne3) {
  2954. GGML_ASSERT(ggml_is_contiguous(a));
  2955. GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
  2956. const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
  2957. struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0);
  2958. ggml_format_name(result, "%s (reshaped)", a->name);
  2959. result->op = GGML_OP_RESHAPE;
  2960. result->src[0] = a;
  2961. return result;
  2962. }
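// ggml_view
// ggml_view_impl creates a view into a's data starting at `offset` bytes, with the extents
// given in ne; the nb byte strides default to a contiguous layout and are overridden by the
// 2d/3d/4d wrappers below to describe strided views.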
  2963. static struct ggml_tensor * ggml_view_impl(
  2964. struct ggml_context * ctx,
  2965. struct ggml_tensor * a,
  2966. int n_dims,
  2967. const int64_t * ne,
  2968. size_t offset) {
  2969. struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset);
  2970. ggml_format_name(result, "%s (view)", a->name);
  2971. ggml_set_op_params(result, &offset, sizeof(offset));
  2972. result->op = GGML_OP_VIEW;
  2973. result->src[0] = a;
  2974. return result;
  2975. }
  2976. // ggml_view_1d
  2977. struct ggml_tensor * ggml_view_1d(
  2978. struct ggml_context * ctx,
  2979. struct ggml_tensor * a,
  2980. int64_t ne0,
  2981. size_t offset) {
  2982. struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset);
  2983. return result;
  2984. }
  2985. // ggml_view_2d
  2986. struct ggml_tensor * ggml_view_2d(
  2987. struct ggml_context * ctx,
  2988. struct ggml_tensor * a,
  2989. int64_t ne0,
  2990. int64_t ne1,
  2991. size_t nb1,
  2992. size_t offset) {
  2993. const int64_t ne[2] = { ne0, ne1 };
  2994. struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset);
  2995. result->nb[1] = nb1;
  2996. result->nb[2] = result->nb[1]*ne1;
  2997. result->nb[3] = result->nb[2];
  2998. return result;
  2999. }
  3000. // ggml_view_3d
  3001. struct ggml_tensor * ggml_view_3d(
  3002. struct ggml_context * ctx,
  3003. struct ggml_tensor * a,
  3004. int64_t ne0,
  3005. int64_t ne1,
  3006. int64_t ne2,
  3007. size_t nb1,
  3008. size_t nb2,
  3009. size_t offset) {
  3010. const int64_t ne[3] = { ne0, ne1, ne2 };
  3011. struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset);
  3012. result->nb[1] = nb1;
  3013. result->nb[2] = nb2;
  3014. result->nb[3] = result->nb[2]*ne2;
  3015. return result;
  3016. }
  3017. // ggml_view_4d
  3018. struct ggml_tensor * ggml_view_4d(
  3019. struct ggml_context * ctx,
  3020. struct ggml_tensor * a,
  3021. int64_t ne0,
  3022. int64_t ne1,
  3023. int64_t ne2,
  3024. int64_t ne3,
  3025. size_t nb1,
  3026. size_t nb2,
  3027. size_t nb3,
  3028. size_t offset) {
  3029. const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
  3030. struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset);
  3031. result->nb[1] = nb1;
  3032. result->nb[2] = nb2;
  3033. result->nb[3] = nb3;
  3034. return result;
  3035. }
  3036. // ggml_permute
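// axisN gives the destination position of source dimension N, e.g.
//   ggml_permute(ctx, a, 1, 0, 2, 3) swaps the first two dimensions (same as ggml_transpose).
// Only ne/nb are rearranged; no data is moved.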
  3037. struct ggml_tensor * ggml_permute(
  3038. struct ggml_context * ctx,
  3039. struct ggml_tensor * a,
  3040. int axis0,
  3041. int axis1,
  3042. int axis2,
  3043. int axis3) {
  3044. GGML_ASSERT(axis0 >= 0 && axis0 < GGML_MAX_DIMS);
  3045. GGML_ASSERT(axis1 >= 0 && axis1 < GGML_MAX_DIMS);
  3046. GGML_ASSERT(axis2 >= 0 && axis2 < GGML_MAX_DIMS);
  3047. GGML_ASSERT(axis3 >= 0 && axis3 < GGML_MAX_DIMS);
  3048. GGML_ASSERT(axis0 != axis1);
  3049. GGML_ASSERT(axis0 != axis2);
  3050. GGML_ASSERT(axis0 != axis3);
  3051. GGML_ASSERT(axis1 != axis2);
  3052. GGML_ASSERT(axis1 != axis3);
  3053. GGML_ASSERT(axis2 != axis3);
  3054. struct ggml_tensor * result = ggml_view_tensor(ctx, a);
  3055. ggml_format_name(result, "%s (permuted)", a->name);
3056. int64_t ne[GGML_MAX_DIMS];
3057. size_t  nb[GGML_MAX_DIMS];
  3058. ne[axis0] = a->ne[0];
  3059. ne[axis1] = a->ne[1];
  3060. ne[axis2] = a->ne[2];
  3061. ne[axis3] = a->ne[3];
  3062. nb[axis0] = a->nb[0];
  3063. nb[axis1] = a->nb[1];
  3064. nb[axis2] = a->nb[2];
  3065. nb[axis3] = a->nb[3];
  3066. result->ne[0] = ne[0];
  3067. result->ne[1] = ne[1];
  3068. result->ne[2] = ne[2];
  3069. result->ne[3] = ne[3];
  3070. result->nb[0] = nb[0];
  3071. result->nb[1] = nb[1];
  3072. result->nb[2] = nb[2];
  3073. result->nb[3] = nb[3];
  3074. result->op = GGML_OP_PERMUTE;
  3075. result->src[0] = a;
  3076. int32_t params[] = { axis0, axis1, axis2, axis3 };
  3077. ggml_set_op_params(result, params, sizeof(params));
  3078. return result;
  3079. }
  3080. // ggml_transpose
  3081. struct ggml_tensor * ggml_transpose(
  3082. struct ggml_context * ctx,
  3083. struct ggml_tensor * a) {
  3084. struct ggml_tensor * result = ggml_view_tensor(ctx, a);
  3085. ggml_format_name(result, "%s (transposed)", a->name);
  3086. result->ne[0] = a->ne[1];
  3087. result->ne[1] = a->ne[0];
  3088. result->nb[0] = a->nb[1];
  3089. result->nb[1] = a->nb[0];
  3090. result->op = GGML_OP_TRANSPOSE;
  3091. result->src[0] = a;
  3092. return result;
  3093. }
  3094. // ggml_get_rows
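// ggml_get_rows gathers rows of a (along ne[1]) using the I32 indices in b, with batching over
// a->ne[2]/a->ne[3] (b->ne[1]/b->ne[2] must match, see the asserts below); the result is
// returned as F32 unless a is I32.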
  3095. struct ggml_tensor * ggml_get_rows(
  3096. struct ggml_context * ctx,
  3097. struct ggml_tensor * a,
  3098. struct ggml_tensor * b) {
  3099. GGML_ASSERT(a->ne[2] == b->ne[1]);
  3100. GGML_ASSERT(a->ne[3] == b->ne[2]);
  3101. GGML_ASSERT(b->ne[3] == 1);
  3102. GGML_ASSERT(b->type == GGML_TYPE_I32);
  3103. // TODO: implement non F32 return
  3104. enum ggml_type type = GGML_TYPE_F32;
  3105. if (a->type == GGML_TYPE_I32) {
  3106. type = a->type;
  3107. }
  3108. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
  3109. result->op = GGML_OP_GET_ROWS;
  3110. result->src[0] = a;
  3111. result->src[1] = b;
  3112. return result;
  3113. }
  3114. // ggml_get_rows_back
  3115. struct ggml_tensor * ggml_get_rows_back(
  3116. struct ggml_context * ctx,
  3117. struct ggml_tensor * a,
  3118. struct ggml_tensor * b,
  3119. struct ggml_tensor * c) {
  3120. GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
  3121. GGML_ASSERT(ggml_is_matrix(c) && (a->ne[0] == c->ne[0]));
  3122. // TODO: implement non F32 return
  3123. //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
  3124. struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c->ne[0], c->ne[1]);
  3125. result->op = GGML_OP_GET_ROWS_BACK;
  3126. result->src[0] = a;
  3127. result->src[1] = b;
  3128. return result;
  3129. }
  3130. // ggml_set_rows
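// ggml_set_rows is the scatter counterpart of ggml_get_rows: in effect, row c[i] of a is
// overwritten with row i of b (b is F32 and is converted to a->type on write), as suggested
// by the asserts below.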
  3131. struct ggml_tensor * ggml_set_rows(
  3132. struct ggml_context * ctx,
  3133. struct ggml_tensor * a,
  3134. struct ggml_tensor * b,
  3135. struct ggml_tensor * c) {
  3136. GGML_ASSERT(a->ne[0] == b->ne[0]);
  3137. GGML_ASSERT(a->ne[2] == b->ne[2]);
  3138. GGML_ASSERT(a->ne[3] == b->ne[3]);
  3139. GGML_ASSERT(b->ne[1] == c->ne[0]);
  3140. GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
  3141. GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
  3142. GGML_ASSERT(c->ne[3] == 1);
  3143. GGML_ASSERT(b->type == GGML_TYPE_F32);
  3144. GGML_ASSERT(c->type == GGML_TYPE_I64 || c->type == GGML_TYPE_I32);
  3145. GGML_ASSERT(ggml_is_contiguous_rows(a));
  3146. GGML_ASSERT(ggml_is_contiguous_rows(b));
  3147. struct ggml_tensor * result = ggml_view_tensor(ctx, a);
  3148. result->op = GGML_OP_SET_ROWS;
  3149. result->src[0] = b;
  3150. result->src[1] = c;
  3151. result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931)
  3152. return result;
  3153. }
  3154. // ggml_diag
  3155. struct ggml_tensor * ggml_diag(
  3156. struct ggml_context * ctx,
  3157. struct ggml_tensor * a) {
  3158. GGML_ASSERT(a->ne[1] == 1);
  3159. const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] };
  3160. struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 4, ne);
  3161. result->op = GGML_OP_DIAG;
  3162. result->src[0] = a;
  3163. return result;
  3164. }
  3165. // ggml_diag_mask_inf
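// ggml_diag_mask_inf sets the entries above the diagonal offset by n_past (i.e. "future"
// columns) to -INFINITY, as used for causal attention masking; ggml_diag_mask_zero below
// writes 0 instead.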
  3166. static struct ggml_tensor * ggml_diag_mask_inf_impl(
  3167. struct ggml_context * ctx,
  3168. struct ggml_tensor * a,
  3169. int n_past,
  3170. bool inplace) {
  3171. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  3172. int32_t params[] = { n_past };
  3173. ggml_set_op_params(result, params, sizeof(params));
  3174. result->op = GGML_OP_DIAG_MASK_INF;
  3175. result->src[0] = a;
  3176. return result;
  3177. }
  3178. struct ggml_tensor * ggml_diag_mask_inf(
  3179. struct ggml_context * ctx,
  3180. struct ggml_tensor * a,
  3181. int n_past) {
  3182. return ggml_diag_mask_inf_impl(ctx, a, n_past, false);
  3183. }
  3184. struct ggml_tensor * ggml_diag_mask_inf_inplace(
  3185. struct ggml_context * ctx,
  3186. struct ggml_tensor * a,
  3187. int n_past) {
  3188. return ggml_diag_mask_inf_impl(ctx, a, n_past, true);
  3189. }
  3190. // ggml_diag_mask_zero
  3191. static struct ggml_tensor * ggml_diag_mask_zero_impl(
  3192. struct ggml_context * ctx,
  3193. struct ggml_tensor * a,
  3194. int n_past,
  3195. bool inplace) {
  3196. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  3197. int32_t params[] = { n_past };
  3198. ggml_set_op_params(result, params, sizeof(params));
  3199. result->op = GGML_OP_DIAG_MASK_ZERO;
  3200. result->src[0] = a;
  3201. return result;
  3202. }
  3203. struct ggml_tensor * ggml_diag_mask_zero(
  3204. struct ggml_context * ctx,
  3205. struct ggml_tensor * a,
  3206. int n_past) {
  3207. return ggml_diag_mask_zero_impl(ctx, a, n_past, false);
  3208. }
  3209. struct ggml_tensor * ggml_diag_mask_zero_inplace(
  3210. struct ggml_context * ctx,
  3211. struct ggml_tensor * a,
  3212. int n_past) {
  3213. return ggml_diag_mask_zero_impl(ctx, a, n_past, true);
  3214. }
  3215. // ggml_soft_max
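// ggml_soft_max_impl is the fused soft_max(a*scale + mask*slope) used for attention: the
// optional mask is broadcast over ne[2]/ne[3], and max_bias > 0 enables the ALiBi-style
// per-head slopes derived from it (a mask is then required, see the asserts below).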
  3216. static struct ggml_tensor * ggml_soft_max_impl(
  3217. struct ggml_context * ctx,
  3218. struct ggml_tensor * a,
  3219. struct ggml_tensor * mask,
  3220. float scale,
  3221. float max_bias,
  3222. bool inplace) {
  3223. GGML_ASSERT(ggml_is_contiguous(a));
  3224. if (mask) {
  3225. GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
  3226. GGML_ASSERT(ggml_is_contiguous(mask));
  3227. GGML_ASSERT(mask->ne[0] == a->ne[0]);
  3228. GGML_ASSERT(mask->ne[1] >= a->ne[1]);
  3229. GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
  3230. GGML_ASSERT(a->ne[3]%mask->ne[3] == 0);
  3231. }
  3232. if (max_bias > 0.0f) {
  3233. GGML_ASSERT(mask);
  3234. }
  3235. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  3236. float params[] = { scale, max_bias };
  3237. ggml_set_op_params(result, params, sizeof(params));
  3238. result->op = GGML_OP_SOFT_MAX;
  3239. result->src[0] = a;
  3240. result->src[1] = mask;
  3241. return result;
  3242. }
  3243. struct ggml_tensor * ggml_soft_max(
  3244. struct ggml_context * ctx,
  3245. struct ggml_tensor * a) {
  3246. return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);
  3247. }
  3248. struct ggml_tensor * ggml_soft_max_inplace(
  3249. struct ggml_context * ctx,
  3250. struct ggml_tensor * a) {
  3251. return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);
  3252. }
  3253. struct ggml_tensor * ggml_soft_max_ext(
  3254. struct ggml_context * ctx,
  3255. struct ggml_tensor * a,
  3256. struct ggml_tensor * mask,
  3257. float scale,
  3258. float max_bias) {
  3259. return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
  3260. }
  3261. struct ggml_tensor * ggml_soft_max_ext_inplace(
  3262. struct ggml_context * ctx,
  3263. struct ggml_tensor * a,
  3264. struct ggml_tensor * mask,
  3265. float scale,
  3266. float max_bias) {
  3267. return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, true);
  3268. }
  3269. void ggml_soft_max_add_sinks(
  3270. struct ggml_tensor * a,
  3271. struct ggml_tensor * sinks) {
  3272. if (!sinks) {
  3273. a->src[2] = NULL;
  3274. return;
  3275. }
  3276. GGML_ASSERT(a->op == GGML_OP_SOFT_MAX);
  3277. GGML_ASSERT(a->src[2] == NULL);
  3278. GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
  3279. GGML_ASSERT(sinks->type == GGML_TYPE_F32);
  3280. a->src[2] = sinks;
  3281. }
  3282. // ggml_soft_max_ext_back
  3283. static struct ggml_tensor * ggml_soft_max_ext_back_impl(
  3284. struct ggml_context * ctx,
  3285. struct ggml_tensor * a,
  3286. struct ggml_tensor * b,
  3287. float scale,
  3288. float max_bias,
  3289. bool inplace) {
  3290. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  3291. result->op = GGML_OP_SOFT_MAX_BACK;
  3292. result->src[0] = a;
  3293. result->src[1] = b;
  3294. memcpy((float *) result->op_params + 0, &scale, sizeof(float));
  3295. memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));
  3296. return result;
  3297. }
  3298. struct ggml_tensor * ggml_soft_max_ext_back(
  3299. struct ggml_context * ctx,
  3300. struct ggml_tensor * a,
  3301. struct ggml_tensor * b,
  3302. float scale,
  3303. float max_bias) {
  3304. return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);
  3305. }
  3306. struct ggml_tensor * ggml_soft_max_ext_back_inplace(
  3307. struct ggml_context * ctx,
  3308. struct ggml_tensor * a,
  3309. struct ggml_tensor * b,
  3310. float scale,
  3311. float max_bias) {
  3312. return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
  3313. }
  3314. // ggml_rope
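// ggml_rope_impl packs its configuration into op_params as
//   int32 [0..4]  : n_past (unused, 0), n_dims, mode, n_ctx (unused, 0), n_ctx_orig
//   float [5..10] : freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
//   int32 [11..14]: the M-RoPE sections (zeroed when unused)
// b holds one I32 position per token (4 per token for M-RoPE) and c optionally holds F32
// frequency factors.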
  3315. static struct ggml_tensor * ggml_rope_impl(
  3316. struct ggml_context * ctx,
  3317. struct ggml_tensor * a,
  3318. struct ggml_tensor * b,
  3319. struct ggml_tensor * c,
  3320. int n_dims,
  3321. int sections[GGML_MROPE_SECTIONS],
  3322. int mode,
  3323. int n_ctx_orig,
  3324. float freq_base,
  3325. float freq_scale,
  3326. float ext_factor,
  3327. float attn_factor,
  3328. float beta_fast,
  3329. float beta_slow,
  3330. bool inplace) {
  3331. GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
  3332. GGML_ASSERT(ggml_is_vector(b));
  3333. GGML_ASSERT(b->type == GGML_TYPE_I32);
  3334. bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
  3335. if (mrope_used) {
  3336. GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
  3337. } else {
  3338. GGML_ASSERT(a->ne[2] == b->ne[0]);
  3339. }
  3340. if (c) {
  3341. GGML_ASSERT(c->type == GGML_TYPE_F32);
  3342. GGML_ASSERT(c->ne[0] >= n_dims / 2);
  3343. }
  3344. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  3345. int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
  3346. memcpy(params + 5, &freq_base, sizeof(float));
  3347. memcpy(params + 6, &freq_scale, sizeof(float));
  3348. memcpy(params + 7, &ext_factor, sizeof(float));
  3349. memcpy(params + 8, &attn_factor, sizeof(float));
  3350. memcpy(params + 9, &beta_fast, sizeof(float));
  3351. memcpy(params + 10, &beta_slow, sizeof(float));
  3352. if (mrope_used && sections) {
  3353. memcpy(params + 11, sections, sizeof(int32_t) * GGML_MROPE_SECTIONS);
  3354. } else {
  3355. memset(params + 11, 0, sizeof(int32_t) * GGML_MROPE_SECTIONS);
  3356. }
  3357. ggml_set_op_params(result, params, sizeof(params));
  3358. result->op = GGML_OP_ROPE;
  3359. result->src[0] = a;
  3360. result->src[1] = b;
  3361. result->src[2] = c;
  3362. return result;
  3363. }
  3364. struct ggml_tensor * ggml_rope(
  3365. struct ggml_context * ctx,
  3366. struct ggml_tensor * a,
  3367. struct ggml_tensor * b,
  3368. int n_dims,
  3369. int mode) {
  3370. return ggml_rope_impl(
  3371. ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
  3372. );
  3373. }
  3374. struct ggml_tensor * ggml_rope_multi(
  3375. struct ggml_context * ctx,
  3376. struct ggml_tensor * a,
  3377. struct ggml_tensor * b,
  3378. struct ggml_tensor * c,
  3379. int n_dims,
  3380. int sections[GGML_MROPE_SECTIONS],
  3381. int mode,
  3382. int n_ctx_orig,
  3383. float freq_base,
  3384. float freq_scale,
  3385. float ext_factor,
  3386. float attn_factor,
  3387. float beta_fast,
  3388. float beta_slow) {
  3389. return ggml_rope_impl(
  3390. ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
  3391. ext_factor, attn_factor, beta_fast, beta_slow, false
  3392. );
  3393. }
  3394. struct ggml_tensor * ggml_rope_multi_inplace(
  3395. struct ggml_context * ctx,
  3396. struct ggml_tensor * a,
  3397. struct ggml_tensor * b,
  3398. struct ggml_tensor * c,
  3399. int n_dims,
  3400. int sections[GGML_MROPE_SECTIONS],
  3401. int mode,
  3402. int n_ctx_orig,
  3403. float freq_base,
  3404. float freq_scale,
  3405. float ext_factor,
  3406. float attn_factor,
  3407. float beta_fast,
  3408. float beta_slow) {
  3409. return ggml_rope_impl(
  3410. ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
  3411. ext_factor, attn_factor, beta_fast, beta_slow, true
  3412. );
  3413. }
  3414. struct ggml_tensor * ggml_rope_inplace(
  3415. struct ggml_context * ctx,
  3416. struct ggml_tensor * a,
  3417. struct ggml_tensor * b,
  3418. int n_dims,
  3419. int mode) {
  3420. return ggml_rope_impl(
  3421. ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
  3422. );
  3423. }
  3424. struct ggml_tensor * ggml_rope_ext(
  3425. struct ggml_context * ctx,
  3426. struct ggml_tensor * a,
  3427. struct ggml_tensor * b,
  3428. struct ggml_tensor * c,
  3429. int n_dims,
  3430. int mode,
  3431. int n_ctx_orig,
  3432. float freq_base,
  3433. float freq_scale,
  3434. float ext_factor,
  3435. float attn_factor,
  3436. float beta_fast,
  3437. float beta_slow) {
  3438. return ggml_rope_impl(
  3439. ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
  3440. ext_factor, attn_factor, beta_fast, beta_slow, false
  3441. );
  3442. }
  3443. struct ggml_tensor * ggml_rope_ext_inplace(
  3444. struct ggml_context * ctx,
  3445. struct ggml_tensor * a,
  3446. struct ggml_tensor * b,
  3447. struct ggml_tensor * c,
  3448. int n_dims,
  3449. int mode,
  3450. int n_ctx_orig,
  3451. float freq_base,
  3452. float freq_scale,
  3453. float ext_factor,
  3454. float attn_factor,
  3455. float beta_fast,
  3456. float beta_slow) {
  3457. return ggml_rope_impl(
  3458. ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
  3459. ext_factor, attn_factor, beta_fast, beta_slow, true
  3460. );
  3461. }
  3462. struct ggml_tensor * ggml_rope_custom(
  3463. struct ggml_context * ctx,
  3464. struct ggml_tensor * a,
  3465. struct ggml_tensor * b,
  3466. int n_dims,
  3467. int mode,
  3468. int n_ctx_orig,
  3469. float freq_base,
  3470. float freq_scale,
  3471. float ext_factor,
  3472. float attn_factor,
  3473. float beta_fast,
  3474. float beta_slow) {
  3475. return ggml_rope_impl(
  3476. ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
  3477. ext_factor, attn_factor, beta_fast, beta_slow, false
  3478. );
  3479. }
  3480. struct ggml_tensor * ggml_rope_custom_inplace(
  3481. struct ggml_context * ctx,
  3482. struct ggml_tensor * a,
  3483. struct ggml_tensor * b,
  3484. int n_dims,
  3485. int mode,
  3486. int n_ctx_orig,
  3487. float freq_base,
  3488. float freq_scale,
  3489. float ext_factor,
  3490. float attn_factor,
  3491. float beta_fast,
  3492. float beta_slow) {
  3493. return ggml_rope_impl(
  3494. ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
  3495. ext_factor, attn_factor, beta_fast, beta_slow, true
  3496. );
  3497. }
  3498. // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
  3499. // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
  3500. static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
  3501. return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
  3502. }
  3503. void ggml_rope_yarn_corr_dims(
  3504. int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
  3505. ) {
  3506. // start and end correction dims
  3507. float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
  3508. float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
  3509. dims[0] = MAX(0, start);
  3510. dims[1] = MIN(n_dims - 1, end);
  3511. }
  3512. // ggml_rope_back
  3513. struct ggml_tensor * ggml_rope_ext_back(
  3514. struct ggml_context * ctx,
  3515. struct ggml_tensor * a,
  3516. struct ggml_tensor * b,
  3517. struct ggml_tensor * c,
  3518. int n_dims,
  3519. int mode,
  3520. int n_ctx_orig,
  3521. float freq_base,
  3522. float freq_scale,
  3523. float ext_factor,
  3524. float attn_factor,
  3525. float beta_fast,
  3526. float beta_slow) {
  3527. struct ggml_tensor * result = ggml_rope_ext(
  3528. ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
  3529. result->op = GGML_OP_ROPE_BACK;
  3530. return result;
  3531. }
  3532. struct ggml_tensor * ggml_rope_multi_back(
  3533. struct ggml_context * ctx,
  3534. struct ggml_tensor * a,
  3535. struct ggml_tensor * b,
  3536. struct ggml_tensor * c,
  3537. int n_dims,
  3538. int sections[4],
  3539. int mode,
  3540. int n_ctx_orig,
  3541. float freq_base,
  3542. float freq_scale,
  3543. float ext_factor,
  3544. float attn_factor,
  3545. float beta_fast,
  3546. float beta_slow) {
  3547. struct ggml_tensor * result = ggml_rope_multi(
  3548. ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
  3549. result->op = GGML_OP_ROPE_BACK;
  3550. return result;
  3551. }
  3552. // ggml_clamp
  3553. struct ggml_tensor * ggml_clamp(
  3554. struct ggml_context * ctx,
  3555. struct ggml_tensor * a,
  3556. float min,
  3557. float max) {
3558. // TODO: when implementing backward, fix this:
  3559. struct ggml_tensor * result = ggml_view_tensor(ctx, a);
  3560. float params[] = { min, max };
  3561. ggml_set_op_params(result, params, sizeof(params));
  3562. result->op = GGML_OP_CLAMP;
  3563. result->src[0] = a;
  3564. return result;
  3565. }
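// standard dilated convolution output size: floor((ins + 2*p - d*(ks - 1) - 1) / s) + 1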
  3566. static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
  3567. return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
  3568. }
  3569. // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
  3570. // a: [OC,IC, KH, KW]
  3571. // b: [N, IC, IH, IW]
  3572. // result: [N, OH, OW, IC*KH*KW]
  3573. struct ggml_tensor * ggml_im2col(
  3574. struct ggml_context * ctx,
  3575. struct ggml_tensor * a,
  3576. struct ggml_tensor * b,
  3577. int s0,
  3578. int s1,
  3579. int p0,
  3580. int p1,
  3581. int d0,
  3582. int d1,
  3583. bool is_2D,
  3584. enum ggml_type dst_type) {
  3585. if (is_2D) {
  3586. GGML_ASSERT(a->ne[2] == b->ne[2]);
  3587. } else {
  3588. //GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
  3589. GGML_ASSERT(b->ne[1] == a->ne[1]);
  3590. GGML_ASSERT(b->ne[3] == 1);
  3591. }
  3592. const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
  3593. const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
  3594. GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a");
  3595. GGML_ASSERT((OW > 0) && "b too small compared to a");
  3596. const int64_t ne[4] = {
  3597. is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
  3598. OW,
  3599. is_2D ? OH : b->ne[2],
  3600. is_2D ? b->ne[3] : 1,
  3601. };
  3602. struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
  3603. int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
  3604. ggml_set_op_params(result, params, sizeof(params));
  3605. result->op = GGML_OP_IM2COL;
  3606. result->src[0] = a;
  3607. result->src[1] = b;
  3608. return result;
  3609. }
  3610. struct ggml_tensor * ggml_im2col_back(
  3611. struct ggml_context * ctx,
  3612. struct ggml_tensor * a,
  3613. struct ggml_tensor * b,
  3614. int64_t * ne,
  3615. int s0,
  3616. int s1,
  3617. int p0,
  3618. int p1,
  3619. int d0,
  3620. int d1,
  3621. bool is_2D) {
  3622. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  3623. int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
  3624. ggml_set_op_params(result, params, sizeof(params));
  3625. result->op = GGML_OP_IM2COL_BACK;
  3626. result->src[0] = a;
  3627. result->src[1] = b;
  3628. return result;
  3629. }
  3630. // ggml_conv_1d
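// ggml_conv_1d is implemented as im2col followed by a matrix multiplication:
//   a (kernel): [K, IC, OC], b (data): [IL, IC, N]  ->  result: [OL, OC, N]
// with OL computed by ggml_calc_conv_output_size above.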
  3631. struct ggml_tensor * ggml_conv_1d(
  3632. struct ggml_context * ctx,
  3633. struct ggml_tensor * a,
  3634. struct ggml_tensor * b,
  3635. int s0,
  3636. int p0,
  3637. int d0) {
  3638. struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
  3639. struct ggml_tensor * result =
  3640. ggml_mul_mat(ctx,
  3641. ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
  3642. ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K]
  3643. result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
  3644. return result;
  3645. }
  3646. // ggml_conv_1d_ph
3647. struct ggml_tensor * ggml_conv_1d_ph(
  3648. struct ggml_context * ctx,
  3649. struct ggml_tensor * a,
  3650. struct ggml_tensor * b,
  3651. int s,
  3652. int d) {
  3653. return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
  3654. }
  3655. // ggml_conv_1d_dw
  3656. struct ggml_tensor * ggml_conv_1d_dw(
  3657. struct ggml_context * ctx,
  3658. struct ggml_tensor * a,
  3659. struct ggml_tensor * b,
  3660. int s0,
  3661. int p0,
  3662. int d0) {
  3663. struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);
  3664. struct ggml_tensor * im2col = ggml_im2col(ctx, a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);
  3665. struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a);
  3666. result = ggml_reshape_3d(ctx, result, result->ne[0], result->ne[2], 1);
  3667. return result;
  3668. }
  3669. // ggml_conv_1d_dw_ph
  3670. struct ggml_tensor * ggml_conv_1d_dw_ph(
  3671. struct ggml_context * ctx,
  3672. struct ggml_tensor * a,
  3673. struct ggml_tensor * b,
  3674. int s0,
  3675. int d0) {
  3676. return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0);
  3677. }
  3678. // ggml_conv_transpose_1d
  3679. static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
  3680. return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
  3681. }
  3682. GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
  3683. struct ggml_context * ctx,
  3684. struct ggml_tensor * a,
  3685. struct ggml_tensor * b,
  3686. int s0,
  3687. int p0,
  3688. int d0) {
  3689. GGML_ASSERT(ggml_is_matrix(b));
  3690. GGML_ASSERT(a->ne[2] == b->ne[1]);
  3691. GGML_ASSERT(a->ne[3] == 1);
  3692. GGML_ASSERT(p0 == 0);
  3693. GGML_ASSERT(d0 == 1);
  3694. const int64_t ne[4] = {
  3695. ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
  3696. a->ne[1], b->ne[2], 1,
  3697. };
  3698. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  3699. int32_t params[] = { s0, p0, d0 };
  3700. ggml_set_op_params(result, params, sizeof(params));
  3701. result->op = GGML_OP_CONV_TRANSPOSE_1D;
  3702. result->src[0] = a;
  3703. result->src[1] = b;
  3704. return result;
  3705. }
  3706. // ggml_conv_2d
  3707. // a: [OC,IC, KH, KW]
  3708. // b: [N, IC, IH, IW]
  3709. // result: [N, OC, OH, OW]
  3710. struct ggml_tensor * ggml_conv_2d(
  3711. struct ggml_context * ctx,
  3712. struct ggml_tensor * a,
  3713. struct ggml_tensor * b,
  3714. int s0,
  3715. int s1,
  3716. int p0,
  3717. int p1,
  3718. int d0,
  3719. int d1) {
  3720. struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW]
  3721. struct ggml_tensor * result =
  3722. ggml_mul_mat(ctx,
  3723. ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
  3724. ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW]
  3725. result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW]
  3726. result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW]
  3727. return result;
  3728. }
  3729. // a: [OC*IC, KD, KH, KW]
  3730. // b: [N*IC, ID, IH, IW]
  3731. // result: [N*OD, OH, OW, IC * KD * KH * KW]
  3732. struct ggml_tensor * ggml_im2col_3d(
  3733. struct ggml_context * ctx,
  3734. struct ggml_tensor * a,
  3735. struct ggml_tensor * b,
  3736. int64_t IC,
  3737. int s0, // stride width
  3738. int s1, // stride height
  3739. int s2, // stride depth
  3740. int p0, // padding width
  3741. int p1, // padding height
  3742. int p2, // padding depth
  3743. int d0, // dilation width
  3744. int d1, // dilation height
  3745. int d2, // dilation depth
  3746. enum ggml_type dst_type) {
  3747. const int64_t N = b->ne[3] / IC;
  3748. const int64_t ID = b->ne[2];
  3749. const int64_t IH = b->ne[1];
  3750. const int64_t IW = b->ne[0];
  3751. const int64_t OC = a->ne[3] / IC;
  3752. UNUSED(OC);
  3753. const int64_t KD = a->ne[2];
  3754. const int64_t KH = a->ne[1];
  3755. const int64_t KW = a->ne[0];
  3756. const int64_t OD = ggml_calc_conv_output_size(ID, KD, s2, p2, d2);
  3757. const int64_t OH = ggml_calc_conv_output_size(IH, KH, s1, p1, d1);
  3758. const int64_t OW = ggml_calc_conv_output_size(IW, KW, s0, p0, d0);
  3759. GGML_ASSERT((OD > 0) && "b too small compared to a");
  3760. GGML_ASSERT((OH > 0) && "b too small compared to a");
  3761. GGML_ASSERT((OW > 0) && "b too small compared to a");
  3762. const int64_t ne[4] = {KW*KH*KD*IC, OW, OH, OD*N};
  3763. struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
  3764. int32_t params[] = { s0, s1, s2, p0, p1, p2, d0, d1, d2, (int32_t)IC};
  3765. ggml_set_op_params(result, params, sizeof(params));
  3766. result->op = GGML_OP_IM2COL_3D;
  3767. result->src[0] = a;
  3768. result->src[1] = b;
  3769. return result;
  3770. }
  3771. // a: [OC*IC, KD, KH, KW]
  3772. // b: [N*IC, ID, IH, IW]
  3773. // result: [N*OC, OD, OH, OW]
  3774. struct ggml_tensor * ggml_conv_3d(
  3775. struct ggml_context * ctx,
  3776. struct ggml_tensor * a,
  3777. struct ggml_tensor * b,
  3778. int64_t IC,
  3779. int s0, // stride width
  3780. int s1, // stride height
  3781. int s2, // stride depth
  3782. int p0, // padding width
  3783. int p1, // padding height
  3784. int p2, // padding depth
  3785. int d0, // dilation width
  3786. int d1, // dilation height
  3787. int d2 // dilation depth
  3788. ) {
  3789. struct ggml_tensor * im2col = ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW]
  3790. int64_t OC = a->ne[3] / IC;
  3791. int64_t N = b->ne[3] / IC;
  3792. struct ggml_tensor * result =
  3793. ggml_mul_mat(ctx,
  3794. ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N*OD, OH, OW, IC * KD * KH * KW] => [N*OD*OH*OW, IC * KD * KH * KW]
  3795. ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2] * IC), OC)); // [OC*IC, KD, KH, KW] => [OC, IC * KD * KH * KW]
  3796. int64_t OD = im2col->ne[3] / N;
  3797. result = ggml_reshape_4d(ctx, result, im2col->ne[1]*im2col->ne[2], OD, N, OC); // [OC, N*OD*OH*OW] => [OC, N, OD, OH*OW]
  3798. result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OD, OH*OW]
  3799. result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], OD, OC * N); // [N*OC, OD, OH, OW]
  3800. return result;
  3801. }
  3802. // ggml_conv_2d_sk_p0
  3803. struct ggml_tensor * ggml_conv_2d_sk_p0(
  3804. struct ggml_context * ctx,
  3805. struct ggml_tensor * a,
  3806. struct ggml_tensor * b) {
  3807. return ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1);
  3808. }
  3809. // ggml_conv_2d_s1_ph
  3810. struct ggml_tensor * ggml_conv_2d_s1_ph(
  3811. struct ggml_context * ctx,
  3812. struct ggml_tensor * a,
  3813. struct ggml_tensor * b) {
  3814. return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
  3815. }
  3816. // ggml_conv_2d_dw
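// depthwise 2D convolution built from im2col + per-channel mat-mul: each input channel is
// convolved with its own kernel slice instead of being summed across channels.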
  3817. struct ggml_tensor * ggml_conv_2d_dw(
  3818. struct ggml_context * ctx,
  3819. struct ggml_tensor * a,
  3820. struct ggml_tensor * b,
  3821. int s0,
  3822. int s1,
  3823. int p0,
  3824. int p1,
  3825. int d0,
  3826. int d1) {
  3827. struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
  3828. struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
  3829. ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
  3830. s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
  3831. struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
  3832. new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
  3833. struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
  3834. result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
  3835. return result;
  3836. }
  3837. // ggml_conv_2d_dw_direct
  3838. struct ggml_tensor * ggml_conv_2d_dw_direct(
  3839. struct ggml_context * ctx,
  3840. struct ggml_tensor * a,
  3841. struct ggml_tensor * b,
  3842. int stride0,
  3843. int stride1,
  3844. int pad0,
  3845. int pad1,
  3846. int dilation0,
  3847. int dilation1) {
  3848. GGML_ASSERT(a->ne[2] == 1);
  3849. GGML_ASSERT(a->ne[3] == b->ne[2]);
  3850. int64_t ne[4];
  3851. ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);
  3852. ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);
  3853. ne[2] = b->ne[2];
  3854. ne[3] = b->ne[3];
  3855. struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
  3856. if (ggml_is_contiguous_channels(b)) {
  3857. // Result will be permuted the same way as input (CWHN order)
  3858. const int64_t type_size = ggml_type_size(result->type);
  3859. GGML_ASSERT(ggml_blck_size(result->type) == 1);
  3860. result->nb[0] = result->ne[2] * type_size;
  3861. result->nb[1] = result->ne[0] * result->nb[0];
  3862. result->nb[2] = type_size;
  3863. }
  3864. int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };
  3865. ggml_set_op_params(result, params, sizeof(params));
  3866. result->op = GGML_OP_CONV_2D_DW;
  3867. result->src[0] = a;
  3868. result->src[1] = b;
  3869. return result;
  3870. }
  3871. // ggml_conv_2d_direct
  3872. struct ggml_tensor * ggml_conv_2d_direct(
  3873. struct ggml_context * ctx,
  3874. struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC]
  3875. struct ggml_tensor * b, // input data [W, H, C, N]
  3876. int s0, // stride dimension 0
  3877. int s1, // stride dimension 1
  3878. int p0, // padding dimension 0
  3879. int p1, // padding dimension 1
  3880. int d0, // dilation dimension 0
3881. int d1) { // dilation dimension 1
  3882. GGML_ASSERT(a->ne[2] == b->ne[2]);
  3883. //GGML_ASSERT(a->type == b->type);
  3884. int64_t ne[4];
  3885. ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
  3886. ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
  3887. ne[2] = a->ne[3];
  3888. ne[3] = b->ne[3];
  3889. struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
  3890. ggml_set_op_params_i32(result, 0, s0);
  3891. ggml_set_op_params_i32(result, 1, s1);
  3892. ggml_set_op_params_i32(result, 2, p0);
  3893. ggml_set_op_params_i32(result, 3, p1);
  3894. ggml_set_op_params_i32(result, 4, d0);
  3895. ggml_set_op_params_i32(result, 5, d1);
  3896. result->op = GGML_OP_CONV_2D;
  3897. result->src[0] = a;
  3898. result->src[1] = b;
  3899. return result;
  3900. }
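// Worked shape example (using the usual output-size formula (ins + 2*p - d*(ks - 1) - 1)/s + 1
// computed by ggml_calc_conv_output_size): a 7x7, stride-2, pad-3 convolution on a 224x224 input,
// kernel [7, 7, IC, OC] and input [224, 224, IC, N]:
//   (224 + 2*3 - (1*(7 - 1) + 1))/2 + 1 = 112, i.e. output [112, 112, OC, N]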
  3901. // ggml_conv_3d_direct
  3902. struct ggml_tensor * ggml_conv_3d_direct(
  3903. struct ggml_context * ctx,
  3904. struct ggml_tensor * a,
  3905. struct ggml_tensor * b,
  3906. int s0,
  3907. int s1,
  3908. int s2,
  3909. int p0,
  3910. int p1,
  3911. int p2,
  3912. int d0,
  3913. int d1,
  3914. int d2,
  3915. int c,
  3916. int n,
  3917. int oc) {
  3918. GGML_ASSERT(a->ne[3] == (int64_t) c * oc);
  3919. GGML_ASSERT(b->ne[3] == (int64_t) c * n);
  3920. int64_t ne[4];
  3921. ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
  3922. ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
  3923. ne[2] = ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2);
  3924. ne[3] = (int64_t) oc * n;
  3925. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  3926. ggml_set_op_params_i32(result, 0, s0);
  3927. ggml_set_op_params_i32(result, 1, s1);
  3928. ggml_set_op_params_i32(result, 2, s2);
  3929. ggml_set_op_params_i32(result, 3, p0);
  3930. ggml_set_op_params_i32(result, 4, p1);
  3931. ggml_set_op_params_i32(result, 5, p2);
  3932. ggml_set_op_params_i32(result, 6, d0);
  3933. ggml_set_op_params_i32(result, 7, d1);
  3934. ggml_set_op_params_i32(result, 8, d2);
  3935. ggml_set_op_params_i32(result, 9, c);
  3936. ggml_set_op_params_i32(result, 10, n);
  3937. ggml_set_op_params_i32(result, 11, oc);
  3938. result->op = GGML_OP_CONV_3D;
  3939. result->src[0] = a;
  3940. result->src[1] = b;
  3941. return result;
  3942. }
  3943. // ggml_conv_transpose_2d_p0
  3944. static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
  3945. return (ins - 1) * s - 2 * p + ks;
  3946. }
  3947. struct ggml_tensor * ggml_conv_transpose_2d_p0(
  3948. struct ggml_context * ctx,
  3949. struct ggml_tensor * a,
  3950. struct ggml_tensor * b,
  3951. int stride) {
  3952. GGML_ASSERT(a->ne[3] == b->ne[2]);
  3953. const int64_t ne[4] = {
  3954. ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/),
  3955. ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/),
  3956. a->ne[2], b->ne[3],
  3957. };
  3958. struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  3959. ggml_set_op_params_i32(result, 0, stride);
  3960. result->op = GGML_OP_CONV_TRANSPOSE_2D;
  3961. result->src[0] = a;
  3962. result->src[1] = b;
  3963. return result;
  3964. }
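// Worked example for the transposed-convolution output size above ((ins - 1)*s - 2*p + ks, with p = 0 here):
// a [2, 2, OC, IC] kernel with stride 2 upsamples a [16, 16, IC, N] input to
//   (16 - 1)*2 + 2 = 32 per spatial dimension, i.e. [32, 32, OC, N].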
  3965. // ggml_pool_*
  3966. static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) {
  3967. return (ins + 2 * p - ks) / s + 1;
  3968. }
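// Worked example: ins = 32, ks = 2, s = 2, p = 0  ->  (32 + 2*0 - 2)/2 + 1 = 16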
  3969. // ggml_pool_1d
  3970. struct ggml_tensor * ggml_pool_1d(
  3971. struct ggml_context * ctx,
  3972. struct ggml_tensor * a,
  3973. enum ggml_op_pool op,
  3974. int k0,
  3975. int s0,
  3976. int p0) {
  3977. const int64_t ne[4] = {
  3978. ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
  3979. a->ne[1],
  3980. a->ne[2],
  3981. a->ne[3],
  3982. };
  3983. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  3984. int32_t params[] = { op, k0, s0, p0 };
  3985. ggml_set_op_params(result, params, sizeof(params));
  3986. result->op = GGML_OP_POOL_1D;
  3987. result->src[0] = a;
  3988. return result;
  3989. }
  3990. // ggml_pool_2d
  3991. struct ggml_tensor * ggml_pool_2d(
  3992. struct ggml_context * ctx,
  3993. struct ggml_tensor * a,
  3994. enum ggml_op_pool op,
  3995. int k0,
  3996. int k1,
  3997. int s0,
  3998. int s1,
  3999. float p0,
  4000. float p1) {
  4001. struct ggml_tensor * result;
  4002. const int64_t ne[4] = {
  4003. ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
  4004. ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
  4005. a->ne[2],
  4006. a->ne[3],
  4007. };
  4008. result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  4009. int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
  4010. ggml_set_op_params(result, params, sizeof(params));
  4011. result->op = GGML_OP_POOL_2D;
  4012. result->src[0] = a;
  4013. return result;
  4014. }
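// Illustrative usage (assuming a valid ggml_context * ctx and an input x of shape [W, H, C, N]):
// 2x2 max pooling with stride 2 and no padding halves the spatial dimensions; the result is always
// allocated as F32:
//   struct ggml_tensor * y = ggml_pool_2d(ctx, x, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0); // [W/2, H/2, C, N]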
  4015. struct ggml_tensor * ggml_pool_2d_back(
  4016. struct ggml_context * ctx,
  4017. struct ggml_tensor * a,
  4018. struct ggml_tensor * af,
  4019. enum ggml_op_pool op,
  4020. int k0,
  4021. int k1,
  4022. int s0,
  4023. int s1,
  4024. float p0,
  4025. float p1) {
  4026. struct ggml_tensor * result;
  4027. result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, af->ne);
  4028. int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
  4029. ggml_set_op_params(result, params, sizeof(params));
  4030. result->op = GGML_OP_POOL_2D_BACK;
  4031. result->src[0] = a;
  4032. result->src[1] = af;
  4033. return result;
  4034. }
  4035. // ggml_upscale / ggml_interpolate
  4036. static struct ggml_tensor * ggml_interpolate_impl(
  4037. struct ggml_context * ctx,
  4038. struct ggml_tensor * a,
  4039. int64_t ne0,
  4040. int64_t ne1,
  4041. int64_t ne2,
  4042. int64_t ne3,
  4043. uint32_t mode) {
  4044. GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT);
  4045. // TODO: implement antialias for modes other than bilinear
  4046. GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR);
  4047. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
  4048. ggml_set_op_params_i32(result, 0, (int32_t)mode);
  4049. result->op = GGML_OP_UPSCALE;
  4050. result->src[0] = a;
  4051. return result;
  4052. }
  4053. struct ggml_tensor * ggml_upscale(
  4054. struct ggml_context * ctx,
  4055. struct ggml_tensor * a,
  4056. int scale_factor,
  4057. enum ggml_scale_mode mode) {
  4058. GGML_ASSERT(scale_factor > 1);
  4059. return ggml_interpolate_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
  4060. }
  4061. struct ggml_tensor * ggml_upscale_ext(
  4062. struct ggml_context * ctx,
  4063. struct ggml_tensor * a,
  4064. int ne0,
  4065. int ne1,
  4066. int ne2,
  4067. int ne3,
  4068. enum ggml_scale_mode mode) {
  4069. return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
  4070. }
  4071. struct ggml_tensor * ggml_interpolate(
  4072. struct ggml_context * ctx,
  4073. struct ggml_tensor * a,
  4074. int64_t ne0,
  4075. int64_t ne1,
  4076. int64_t ne2,
  4077. int64_t ne3,
  4078. uint32_t mode) {
  4079. return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
  4080. }
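// Illustrative usage (assuming a valid ggml_context * ctx and an input x of shape [W, H, C, N]):
// antialiased bilinear resize to twice the spatial resolution; the flag is OR-ed into the mode word,
// as checked by the asserts in ggml_interpolate_impl():
//   struct ggml_tensor * y = ggml_interpolate(ctx, x, 2*x->ne[0], 2*x->ne[1], x->ne[2], x->ne[3],
//                                             GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS);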
  4081. // ggml_pad
  4082. struct ggml_tensor * ggml_pad(
  4083. struct ggml_context * ctx,
  4084. struct ggml_tensor * a,
  4085. int p0,
  4086. int p1,
  4087. int p2,
  4088. int p3) {
  4089. return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
  4090. }
  4091. // ggml_pad_circular
  4092. struct ggml_tensor * ggml_pad_circular(
  4093. struct ggml_context * ctx,
  4094. struct ggml_tensor * a,
  4095. int p0,
  4096. int p1,
  4097. int p2,
  4098. int p3) {
  4099. return ggml_pad_ext_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
  4100. }
  4101. struct ggml_tensor * ggml_pad_ext(
  4102. struct ggml_context * ctx,
  4103. struct ggml_tensor * a,
  4104. int lp0,
  4105. int rp0,
  4106. int lp1,
  4107. int rp1,
  4108. int lp2,
  4109. int rp2,
  4110. int lp3,
  4111. int rp3
  4112. ) {
  4113. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
  4114. a->ne[0] + lp0 + rp0,
  4115. a->ne[1] + lp1 + rp1,
  4116. a->ne[2] + lp2 + rp2,
  4117. a->ne[3] + lp3 + rp3);
  4118. ggml_set_op_params_i32(result, 0, lp0);
  4119. ggml_set_op_params_i32(result, 1, rp0);
  4120. ggml_set_op_params_i32(result, 2, lp1);
  4121. ggml_set_op_params_i32(result, 3, rp1);
  4122. ggml_set_op_params_i32(result, 4, lp2);
  4123. ggml_set_op_params_i32(result, 5, rp2);
  4124. ggml_set_op_params_i32(result, 6, lp3);
  4125. ggml_set_op_params_i32(result, 7, rp3);
  4126. ggml_set_op_params_i32(result, 8, 0); // not circular by default
  4127. result->op = GGML_OP_PAD;
  4128. result->src[0] = a;
  4129. return result;
  4130. }
  4131. // ggml_pad_ext_circular
  4132. struct ggml_tensor * ggml_pad_ext_circular(
  4133. struct ggml_context * ctx,
  4134. struct ggml_tensor * a,
  4135. int lp0,
  4136. int rp0,
  4137. int lp1,
  4138. int rp1,
  4139. int lp2,
  4140. int rp2,
  4141. int lp3,
  4142. int rp3
  4143. ) {
  4144. struct ggml_tensor * result = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
  4145. ggml_set_op_params_i32(result, 8, 1); // circular
  4146. return result;
  4147. }
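// The padding ops share one op-param layout: slots 0-7 hold the left/right pad for dims 0-3 and
// slot 8 selects the fill mode (0 = constant zero padding, 1 = circular/wrap-around). For example,
// assuming a valid ggml_context * ctx and a tensor x:
//   struct ggml_tensor * y = ggml_pad(ctx, x, 2, 0, 0, 0);          // grow dim 0 by 2 zero-filled elements
//   struct ggml_tensor * z = ggml_pad_circular(ctx, x, 2, 0, 0, 0); // same shape, but wraps values around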
  4148. // ggml_pad_reflect_1d
  4149. struct ggml_tensor * ggml_pad_reflect_1d(
  4150. struct ggml_context * ctx,
  4151. struct ggml_tensor * a,
  4152. int p0,
  4153. int p1) {
  4154. GGML_ASSERT(p0 >= 0);
  4155. GGML_ASSERT(p1 >= 0);
4156. GGML_ASSERT(p0 < a->ne[0]); // padding length on each side must be less than the
  4157. GGML_ASSERT(p1 < a->ne[0]); // existing length of the dimension being padded
  4158. GGML_ASSERT(ggml_is_contiguous(a));
  4159. GGML_ASSERT(a->type == GGML_TYPE_F32);
  4160. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
  4161. a->ne[0] + p0 + p1,
  4162. a->ne[1],
  4163. a->ne[2],
  4164. a->ne[3]);
  4165. int32_t params[] = { p0, p1 };
  4166. ggml_set_op_params(result, params, sizeof(params));
  4167. result->op = GGML_OP_PAD_REFLECT_1D;
  4168. result->src[0] = a;
  4169. return result;
  4170. }
  4171. // ggml_roll
  4172. struct ggml_tensor * ggml_roll(
  4173. struct ggml_context * ctx,
  4174. struct ggml_tensor * a,
  4175. int shift0,
  4176. int shift1,
  4177. int shift2,
  4178. int shift3) {
  4179. GGML_ASSERT(a->nb[0] == ggml_type_size(a->type));
  4180. GGML_ASSERT(abs(shift0) < a->ne[0]);
  4181. GGML_ASSERT(abs(shift1) < a->ne[1]);
  4182. GGML_ASSERT(abs(shift2) < a->ne[2]);
  4183. GGML_ASSERT(abs(shift3) < a->ne[3]);
  4184. struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
  4185. ggml_set_op_params_i32(result, 0, shift0);
  4186. ggml_set_op_params_i32(result, 1, shift1);
  4187. ggml_set_op_params_i32(result, 2, shift2);
  4188. ggml_set_op_params_i32(result, 3, shift3);
  4189. result->op = GGML_OP_ROLL;
  4190. result->src[0] = a;
  4191. return result;
  4192. }
  4193. // ggml_timestep_embedding
  4194. struct ggml_tensor * ggml_timestep_embedding(
  4195. struct ggml_context * ctx,
  4196. struct ggml_tensor * timesteps,
  4197. int dim,
  4198. int max_period) {
  4199. struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, timesteps->ne[0]);
  4200. ggml_set_op_params_i32(result, 0, dim);
  4201. ggml_set_op_params_i32(result, 1, max_period);
  4202. result->op = GGML_OP_TIMESTEP_EMBEDDING;
  4203. result->src[0] = timesteps;
  4204. return result;
  4205. }
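// Presumably the standard sinusoidal timestep embedding used for diffusion-style time conditioning:
// one slice of `dim` F32 values per entry of `timesteps`, built from sin/cos terms whose frequencies
// are derived from max_period, giving a [dim, n_timesteps] result.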
  4206. // ggml_tri
  4207. struct ggml_tensor * ggml_tri(
  4208. struct ggml_context * ctx,
  4209. struct ggml_tensor * a,
  4210. enum ggml_tri_type type) {
  4211. GGML_ASSERT(a->type == GGML_TYPE_F32);
  4212. GGML_ASSERT(ggml_is_contiguous(a));
  4213. GGML_ASSERT(a->ne[0] == a->ne[1]);
  4214. struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
  4215. ggml_set_op_params_i32(result, 0, type);
  4216. result->op = GGML_OP_TRI;
  4217. result->src[0] = a;
  4218. return result;
  4219. }
  4220. // ggml_fill
  4221. static struct ggml_tensor * ggml_fill_impl(
  4222. struct ggml_context * ctx,
  4223. struct ggml_tensor * a,
  4224. float c,
  4225. bool inplace) {
  4226. GGML_ASSERT(a->type == GGML_TYPE_F32);
  4227. GGML_ASSERT(ggml_is_contiguous(a));
  4228. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  4229. ggml_set_op_params_f32(result, 0, c);
  4230. result->op = GGML_OP_FILL;
  4231. result->src[0] = a;
  4232. return result;
  4233. }
  4234. struct ggml_tensor * ggml_fill(
  4235. struct ggml_context * ctx,
  4236. struct ggml_tensor * a,
  4237. float c) {
  4238. return ggml_fill_impl(ctx, a, c, false);
  4239. }
  4240. struct ggml_tensor * ggml_fill_inplace(
  4241. struct ggml_context * ctx,
  4242. struct ggml_tensor * a,
  4243. float c) {
  4244. return ggml_fill_impl(ctx, a, c, true);
  4245. }
  4246. // ggml_argsort
  4247. struct ggml_tensor * ggml_argsort(
  4248. struct ggml_context * ctx,
  4249. struct ggml_tensor * a,
  4250. enum ggml_sort_order order) {
  4251. GGML_ASSERT(a->ne[0] <= INT32_MAX);
  4252. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
  4253. ggml_set_op_params_i32(result, 0, (int32_t) order);
  4254. result->op = GGML_OP_ARGSORT;
  4255. result->src[0] = a;
  4256. return result;
  4257. }
  4258. // ggml_argsort_top_k
  4259. struct ggml_tensor * ggml_argsort_top_k(
  4260. struct ggml_context * ctx,
  4261. struct ggml_tensor * a,
  4262. int k) {
  4263. GGML_ASSERT(a->ne[0] >= k);
  4264. struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
  4265. result = ggml_view_4d(ctx, result,
  4266. k, result->ne[1], result->ne[2], result->ne[3],
  4267. result->nb[1], result->nb[2], result->nb[3],
  4268. 0);
  4269. return result;
  4270. }
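// Note: ggml_argsort_top_k() sorts every row in full (descending) and then returns a non-contiguous
// k-column view of that result, so the returned indices are the positions of the k largest values in
// each row while the underlying buffer still holds the complete argsort.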
  4271. // ggml_top_k
  4272. struct ggml_tensor * ggml_top_k(
  4273. struct ggml_context * ctx,
  4274. struct ggml_tensor * a,
  4275. int k) {
  4276. GGML_ASSERT(a->ne[0] >= k);
  4277. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]);
  4278. result->op = GGML_OP_TOP_K;
  4279. result->src[0] = a;
  4280. return result;
  4281. }
  4282. // ggml_arange
  4283. struct ggml_tensor * ggml_arange(
  4284. struct ggml_context * ctx,
  4285. float start,
  4286. float stop,
  4287. float step) {
  4288. GGML_ASSERT(stop > start);
  4289. const int64_t steps = (int64_t) ceilf((stop - start) / step);
  4290. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
  4291. ggml_set_op_params_f32(result, 0, start);
  4292. ggml_set_op_params_f32(result, 1, stop);
  4293. ggml_set_op_params_f32(result, 2, step);
  4294. result->op = GGML_OP_ARANGE;
  4295. return result;
  4296. }
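// Worked example: ggml_arange(ctx, 0.0f, 5.0f, 1.0f) gives ceilf((5 - 0)/1) = 5 elements,
// i.e. a 1-D F32 tensor holding { 0, 1, 2, 3, 4 } once the op is computed.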
  4297. // ggml_flash_attn_ext
  4298. struct ggml_tensor * ggml_flash_attn_ext(
  4299. struct ggml_context * ctx,
  4300. struct ggml_tensor * q,
  4301. struct ggml_tensor * k,
  4302. struct ggml_tensor * v,
  4303. struct ggml_tensor * mask,
  4304. float scale,
  4305. float max_bias,
  4306. float logit_softcap) {
  4307. GGML_ASSERT(ggml_can_mul_mat(k, q));
  4308. // TODO: check if vT can be multiplied by (k*qT)
  4309. GGML_ASSERT(q->ne[3] == k->ne[3]);
  4310. GGML_ASSERT(q->ne[3] == v->ne[3]);
  4311. if (mask) {
  4312. GGML_ASSERT(ggml_is_contiguous(mask));
  4313. //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
  4314. GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
  4315. GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
  4316. }
  4317. if (max_bias > 0.0f) {
  4318. GGML_ASSERT(mask);
  4319. }
  4320. // permute(0, 2, 1, 3)
  4321. int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
  4322. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  4323. float params[] = { scale, max_bias, logit_softcap };
  4324. ggml_set_op_params(result, params, sizeof(params));
  4325. result->op = GGML_OP_FLASH_ATTN_EXT;
  4326. result->src[0] = q;
  4327. result->src[1] = k;
  4328. result->src[2] = v;
  4329. result->src[3] = mask;
  4330. return result;
  4331. }
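// Shape sketch (illustrative; d_k/d_v are head sizes, n_q/n_kv token counts, and the K/V head count
// may be smaller than the Q head count, subject to the broadcast rules asserted above):
//   q: [d_k, n_q,  n_head_q,  n_batch]
//   k: [d_k, n_kv, n_head_kv, n_batch]
//   v: [d_v, n_kv, n_head_kv, n_batch]
//   result (F32, already permuted as noted above): [d_v, n_head_q, n_q, n_batch]
// with scale typically set to 1.0f/sqrtf((float) d_k) and logit_softcap = 0.0f to disable softcapping.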
  4332. void ggml_flash_attn_ext_set_prec(
  4333. struct ggml_tensor * a,
  4334. enum ggml_prec prec) {
  4335. GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
  4336. const int32_t prec_i32 = (int32_t) prec;
4337. ggml_set_op_params_i32(a, 3, prec_i32); // op params 0-2 already hold scale, max_bias and logit_softcap
  4338. }
  4339. enum ggml_prec ggml_flash_attn_ext_get_prec(
  4340. const struct ggml_tensor * a) {
  4341. GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
  4342. const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);
  4343. return (enum ggml_prec) prec_i32;
  4344. }
  4345. void ggml_flash_attn_ext_add_sinks(
  4346. struct ggml_tensor * a,
  4347. struct ggml_tensor * sinks) {
  4348. if (!sinks) {
  4349. a->src[4] = NULL;
  4350. return;
  4351. }
  4352. GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
  4353. GGML_ASSERT(a->src[4] == NULL);
  4354. GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
  4355. GGML_ASSERT(sinks->type == GGML_TYPE_F32);
  4356. a->src[4] = sinks;
  4357. }
  4358. // ggml_flash_attn_back
  4359. struct ggml_tensor * ggml_flash_attn_back(
  4360. struct ggml_context * ctx,
  4361. struct ggml_tensor * q,
  4362. struct ggml_tensor * k,
  4363. struct ggml_tensor * v,
  4364. struct ggml_tensor * d,
  4365. bool masked) {
  4366. GGML_ABORT("TODO: adapt to ggml_flash_attn_ext() changes");
  4367. GGML_ASSERT(ggml_can_mul_mat(k, q));
  4368. // TODO: check if vT can be multiplied by (k*qT)
  4369. // d shape [D,N,ne2,ne3]
  4370. // q shape [D,N,ne2,ne3]
  4371. // k shape [D,M,kvne2,ne3]
  4372. // v shape [M,D,kvne2,ne3]
  4373. const int64_t D = q->ne[0];
  4374. const int64_t N = q->ne[1];
  4375. const int64_t M = k->ne[1];
  4376. const int64_t ne2 = q->ne[2];
  4377. const int64_t ne3 = q->ne[3];
  4378. const int64_t kvne2 = k->ne[2];
  4379. GGML_ASSERT(k->ne[0] == D);
  4380. GGML_ASSERT(v->ne[0] == M);
  4381. GGML_ASSERT(v->ne[1] == D);
  4382. GGML_ASSERT(d->ne[0] == D);
  4383. GGML_ASSERT(d->ne[1] == N);
  4384. GGML_ASSERT(k->ne[2] == kvne2);
  4385. GGML_ASSERT(k->ne[3] == ne3);
  4386. GGML_ASSERT(v->ne[2] == kvne2);
  4387. GGML_ASSERT(v->ne[3] == ne3);
  4388. GGML_ASSERT(d->ne[2] == ne2);
  4389. GGML_ASSERT(d->ne[3] == ne3);
  4390. GGML_ASSERT(ne2 % kvne2 == 0);
4391. // store gradients of q, k and v as contiguous tensors concatenated in result.
  4392. // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
  4393. const int64_t elem_q = ggml_nelements(q);
  4394. const int64_t elem_k = ggml_nelements(k);
  4395. const int64_t elem_v = ggml_nelements(v);
  4396. enum ggml_type result_type = GGML_TYPE_F32;
  4397. GGML_ASSERT(ggml_blck_size(result_type) == 1);
  4398. const size_t tsize = ggml_type_size(result_type);
  4399. const size_t offs_q = 0;
  4400. const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
  4401. const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
  4402. const size_t end = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
  4403. const size_t nelements = (end + tsize - 1)/tsize;
  4404. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);
  4405. int32_t masked_i = masked ? 1 : 0;
  4406. ggml_set_op_params(result, &masked_i, sizeof(masked_i));
  4407. result->op = GGML_OP_FLASH_ATTN_BACK;
  4408. result->src[0] = q;
  4409. result->src[1] = k;
  4410. result->src[2] = v;
  4411. result->src[3] = d;
  4412. return result;
  4413. }
  4414. // ggml_ssm_conv
  4415. struct ggml_tensor * ggml_ssm_conv(
  4416. struct ggml_context * ctx,
  4417. struct ggml_tensor * sx,
  4418. struct ggml_tensor * c) {
  4419. GGML_ASSERT(ggml_is_3d(sx));
  4420. GGML_ASSERT(ggml_is_matrix(c));
  4421. const int64_t d_conv = c->ne[0];
  4422. const int64_t d_inner = c->ne[1];
  4423. const int64_t n_t = sx->ne[0] - d_conv + 1; // tokens per sequence
  4424. const int64_t n_s = sx->ne[2];
4425. // TODO: maybe support strides other than 1?
  4426. GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
  4427. GGML_ASSERT(sx->ne[1] == d_inner);
  4428. GGML_ASSERT(n_t >= 0);
  4429. struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);
  4430. result->op = GGML_OP_SSM_CONV;
  4431. result->src[0] = sx;
  4432. result->src[1] = c;
  4433. return result;
  4434. }
  4435. // ggml_ssm_scan
  4436. struct ggml_tensor * ggml_ssm_scan(
  4437. struct ggml_context * ctx,
  4438. struct ggml_tensor * s,
  4439. struct ggml_tensor * x,
  4440. struct ggml_tensor * dt,
  4441. struct ggml_tensor * A,
  4442. struct ggml_tensor * B,
  4443. struct ggml_tensor * C,
  4444. struct ggml_tensor * ids) {
  4445. GGML_ASSERT(ggml_is_contiguous(s));
  4446. GGML_ASSERT(ggml_is_contiguous(dt));
  4447. GGML_ASSERT(ggml_is_contiguous(A));
  4448. GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
  4449. GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
  4450. GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
  4451. GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
  4452. GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
  4453. GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
  4454. GGML_ASSERT(ggml_are_same_shape(B, C));
  4455. GGML_ASSERT(ids->type == GGML_TYPE_I32);
  4456. {
  4457. const int64_t d_state = s->ne[0];
  4458. const int64_t head_dim = x->ne[0];
  4459. const int64_t n_head = x->ne[1];
  4460. const int64_t n_seq_tokens = x->ne[2];
  4461. const int64_t n_seqs = x->ne[3];
  4462. GGML_ASSERT(dt->ne[0] == n_head);
  4463. GGML_ASSERT(dt->ne[1] == n_seq_tokens);
  4464. GGML_ASSERT(dt->ne[2] == n_seqs);
  4465. GGML_ASSERT(ggml_is_3d(dt));
  4466. GGML_ASSERT(s->ne[1] == head_dim);
  4467. GGML_ASSERT(s->ne[2] == n_head);
  4468. GGML_ASSERT(B->ne[0] == d_state);
  4469. GGML_ASSERT(B->ne[2] == n_seq_tokens);
  4470. GGML_ASSERT(B->ne[3] == n_seqs);
  4471. GGML_ASSERT(ids->ne[0] == n_seqs);
  4472. GGML_ASSERT(ggml_is_vector(ids));
  4473. GGML_ASSERT(A->ne[1] == n_head);
  4474. GGML_ASSERT(ggml_is_matrix(A));
  4475. if (A->ne[0] != 1) {
  4476. // Mamba-1 has more granular decay factors
  4477. GGML_ASSERT(A->ne[0] == d_state);
  4478. }
  4479. }
  4480. // concatenated y + ssm_states
  4481. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
  4482. result->op = GGML_OP_SSM_SCAN;
  4483. result->src[0] = s;
  4484. result->src[1] = x;
  4485. result->src[2] = dt;
  4486. result->src[3] = A;
  4487. result->src[4] = B;
  4488. result->src[5] = C;
  4489. result->src[6] = ids;
  4490. return result;
  4491. }
  4492. // ggml_win_part
  4493. struct ggml_tensor * ggml_win_part(
  4494. struct ggml_context * ctx,
  4495. struct ggml_tensor * a,
  4496. int w) {
  4497. GGML_ASSERT(a->ne[3] == 1);
  4498. GGML_ASSERT(a->type == GGML_TYPE_F32);
  4499. // padding
  4500. const int px = (w - a->ne[1]%w)%w;
  4501. const int py = (w - a->ne[2]%w)%w;
  4502. const int npx = (px + a->ne[1])/w;
  4503. const int npy = (py + a->ne[2])/w;
  4504. const int np = npx*npy;
  4505. const int64_t ne[4] = { a->ne[0], w, w, np, };
  4506. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  4507. int32_t params[] = { npx, npy, w };
  4508. ggml_set_op_params(result, params, sizeof(params));
  4509. result->op = GGML_OP_WIN_PART;
  4510. result->src[0] = a;
  4511. return result;
  4512. }
  4513. // ggml_win_unpart
  4514. struct ggml_tensor * ggml_win_unpart(
  4515. struct ggml_context * ctx,
  4516. struct ggml_tensor * a,
  4517. int w0,
  4518. int h0,
  4519. int w) {
  4520. GGML_ASSERT(a->type == GGML_TYPE_F32);
  4521. const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
  4522. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
  4523. int32_t params[] = { w };
  4524. ggml_set_op_params(result, params, sizeof(params));
  4525. result->op = GGML_OP_WIN_UNPART;
  4526. result->src[0] = a;
  4527. return result;
  4528. }
  4529. // ggml_get_rel_pos
  4530. struct ggml_tensor * ggml_get_rel_pos(
  4531. struct ggml_context * ctx,
  4532. struct ggml_tensor * a,
  4533. int qh,
  4534. int kh) {
  4535. GGML_ASSERT(qh == kh);
  4536. GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]);
  4537. const int64_t ne[4] = { a->ne[0], kh, qh, 1, };
  4538. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne);
  4539. result->op = GGML_OP_GET_REL_POS;
  4540. result->src[0] = a;
  4541. return result;
  4542. }
  4543. // ggml_add_rel_pos
  4544. static struct ggml_tensor * ggml_add_rel_pos_impl(
  4545. struct ggml_context * ctx,
  4546. struct ggml_tensor * a,
  4547. struct ggml_tensor * pw,
  4548. struct ggml_tensor * ph,
  4549. bool inplace) {
  4550. GGML_ASSERT(ggml_are_same_shape(pw, ph));
  4551. GGML_ASSERT(ggml_is_contiguous(a));
  4552. GGML_ASSERT(ggml_is_contiguous(pw));
  4553. GGML_ASSERT(ggml_is_contiguous(ph));
  4554. GGML_ASSERT(ph->type == GGML_TYPE_F32);
  4555. GGML_ASSERT(pw->type == GGML_TYPE_F32);
  4556. GGML_ASSERT(pw->ne[3] == a->ne[2]);
  4557. GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]);
  4558. GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]);
  4559. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  4560. ggml_set_op_params_i32(result, 0, inplace ? 1 : 0);
  4561. result->op = GGML_OP_ADD_REL_POS;
  4562. result->src[0] = a;
  4563. result->src[1] = pw;
  4564. result->src[2] = ph;
  4565. return result;
  4566. }
  4567. struct ggml_tensor * ggml_add_rel_pos(
  4568. struct ggml_context * ctx,
  4569. struct ggml_tensor * a,
  4570. struct ggml_tensor * pw,
  4571. struct ggml_tensor * ph) {
  4572. return ggml_add_rel_pos_impl(ctx, a, pw, ph, false);
  4573. }
  4574. struct ggml_tensor * ggml_add_rel_pos_inplace(
  4575. struct ggml_context * ctx,
  4576. struct ggml_tensor * a,
  4577. struct ggml_tensor * pw,
  4578. struct ggml_tensor * ph) {
  4579. return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
  4580. }
  4581. // ggml_rwkv_wkv6
  4582. struct ggml_tensor * ggml_rwkv_wkv6(
  4583. struct ggml_context * ctx,
  4584. struct ggml_tensor * k,
  4585. struct ggml_tensor * v,
  4586. struct ggml_tensor * r,
  4587. struct ggml_tensor * tf,
  4588. struct ggml_tensor * td,
  4589. struct ggml_tensor * state) {
  4590. GGML_ASSERT(ggml_is_contiguous(k));
  4591. GGML_ASSERT(ggml_is_contiguous(v));
  4592. GGML_ASSERT(ggml_is_contiguous(r));
  4593. GGML_ASSERT(ggml_is_contiguous(tf));
  4594. GGML_ASSERT(ggml_is_contiguous(td));
  4595. GGML_ASSERT(ggml_is_contiguous(state));
  4596. const int64_t S = k->ne[0];
  4597. const int64_t H = k->ne[1];
  4598. const int64_t n_tokens = k->ne[2];
  4599. const int64_t n_seqs = state->ne[1];
  4600. {
  4601. GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
  4602. GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
  4603. GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
  4604. GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
  4605. }
  4606. // concat output and new_state
  4607. const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
  4608. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  4609. result->op = GGML_OP_RWKV_WKV6;
  4610. result->src[0] = k;
  4611. result->src[1] = v;
  4612. result->src[2] = r;
  4613. result->src[3] = tf;
  4614. result->src[4] = td;
  4615. result->src[5] = state;
  4616. return result;
  4617. }
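// The result packs the per-token output and the updated state into one [S*H, n_tokens + S*n_seqs]
// F32 tensor. A sketch of how a caller might split it back apart (illustrative, using the same S, H,
// n_tokens and n_seqs as above):
//   struct ggml_tensor * out       = ggml_view_2d(ctx, result, S*H, n_tokens, result->nb[1], 0);
//   struct ggml_tensor * new_state = ggml_view_2d(ctx, result, S*H, S*n_seqs, result->nb[1],
//                                                 n_tokens*result->nb[1]);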
  4618. // ggml_gated_linear_attn
  4619. struct ggml_tensor * ggml_gated_linear_attn(
  4620. struct ggml_context * ctx,
  4621. struct ggml_tensor * k,
  4622. struct ggml_tensor * v,
  4623. struct ggml_tensor * q,
  4624. struct ggml_tensor * g,
  4625. struct ggml_tensor * state,
  4626. float scale) {
  4627. GGML_ASSERT(ggml_is_contiguous(k));
  4628. GGML_ASSERT(ggml_is_contiguous(v));
  4629. GGML_ASSERT(ggml_is_contiguous(q));
  4630. GGML_ASSERT(ggml_is_contiguous(g));
  4631. GGML_ASSERT(ggml_is_contiguous(state));
  4632. const int64_t S = k->ne[0];
  4633. const int64_t H = k->ne[1];
  4634. const int64_t n_tokens = k->ne[2];
  4635. const int64_t n_seqs = state->ne[1];
  4636. {
  4637. GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
  4638. GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
  4639. GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
  4640. GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
  4641. }
  4642. // concat output and new_state
  4643. const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
  4644. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  4645. ggml_set_op_params_f32(result, 0, scale);
  4646. result->op = GGML_OP_GATED_LINEAR_ATTN;
  4647. result->src[0] = k;
  4648. result->src[1] = v;
  4649. result->src[2] = q;
  4650. result->src[3] = g;
  4651. result->src[4] = state;
  4652. return result;
  4653. }
  4654. // ggml_rwkv_wkv7
  4655. struct ggml_tensor * ggml_rwkv_wkv7(
  4656. struct ggml_context * ctx,
  4657. struct ggml_tensor * r,
  4658. struct ggml_tensor * w,
  4659. struct ggml_tensor * k,
  4660. struct ggml_tensor * v,
  4661. struct ggml_tensor * a,
  4662. struct ggml_tensor * b,
  4663. struct ggml_tensor * state) {
  4664. GGML_ASSERT(ggml_is_contiguous(r));
  4665. GGML_ASSERT(ggml_is_contiguous(w));
  4666. GGML_ASSERT(ggml_is_contiguous(k));
  4667. GGML_ASSERT(ggml_is_contiguous(v));
  4668. GGML_ASSERT(ggml_is_contiguous(a));
  4669. GGML_ASSERT(ggml_is_contiguous(b));
  4670. GGML_ASSERT(ggml_is_contiguous(state));
  4671. const int64_t S = k->ne[0];
  4672. const int64_t H = k->ne[1];
  4673. const int64_t n_tokens = k->ne[2];
  4674. const int64_t n_seqs = state->ne[1];
  4675. {
  4676. GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
  4677. GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
  4678. GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
  4679. GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
  4680. GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
  4681. GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
  4682. }
  4683. // concat output and new_state
  4684. const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
  4685. struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
  4686. result->op = GGML_OP_RWKV_WKV7;
  4687. result->src[0] = r;
  4688. result->src[1] = w;
  4689. result->src[2] = k;
  4690. result->src[3] = v;
  4691. result->src[4] = a;
  4692. result->src[5] = b;
  4693. result->src[6] = state;
  4694. return result;
  4695. }
  4696. // ggml_unary
  4697. static struct ggml_tensor * ggml_unary_impl(
  4698. struct ggml_context * ctx,
  4699. struct ggml_tensor * a,
  4700. enum ggml_unary_op op,
  4701. bool inplace) {
  4702. GGML_ASSERT(ggml_is_contiguous_1(a));
  4703. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  4704. ggml_set_op_params_i32(result, 0, (int32_t) op);
  4705. result->op = GGML_OP_UNARY;
  4706. result->src[0] = a;
  4707. return result;
  4708. }
  4709. struct ggml_tensor * ggml_unary(
  4710. struct ggml_context * ctx,
  4711. struct ggml_tensor * a,
  4712. enum ggml_unary_op op) {
  4713. return ggml_unary_impl(ctx, a, op, false);
  4714. }
  4715. struct ggml_tensor * ggml_unary_inplace(
  4716. struct ggml_context * ctx,
  4717. struct ggml_tensor * a,
  4718. enum ggml_unary_op op) {
  4719. return ggml_unary_impl(ctx, a, op, true);
  4720. }
  4721. // ggml_map_custom1
  4722. static struct ggml_tensor * ggml_map_custom1_impl(
  4723. struct ggml_context * ctx,
  4724. struct ggml_tensor * a,
  4725. const ggml_custom1_op_t fun,
  4726. int n_tasks,
  4727. void * userdata,
  4728. bool inplace) {
  4729. GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
  4730. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  4731. struct ggml_map_custom1_op_params params = {
  4732. /*.fun =*/ fun,
  4733. /*.n_tasks =*/ n_tasks,
  4734. /*.userdata =*/ userdata
  4735. };
  4736. ggml_set_op_params(result, &params, sizeof(params));
  4737. result->op = GGML_OP_MAP_CUSTOM1;
  4738. result->src[0] = a;
  4739. return result;
  4740. }
  4741. struct ggml_tensor * ggml_map_custom1(
  4742. struct ggml_context * ctx,
  4743. struct ggml_tensor * a,
  4744. const ggml_custom1_op_t fun,
  4745. int n_tasks,
  4746. void * userdata) {
  4747. return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false);
  4748. }
  4749. struct ggml_tensor * ggml_map_custom1_inplace(
  4750. struct ggml_context * ctx,
  4751. struct ggml_tensor * a,
  4752. const ggml_custom1_op_t fun,
  4753. int n_tasks,
  4754. void * userdata) {
  4755. return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true);
  4756. }
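// Illustrative custom-op callback (assuming the ggml_custom1_op_t signature from ggml.h,
// fun(dst, a, ith, nth, userdata), where ith/nth identify the calling thread out of n_tasks,
// and assuming a is a contiguous F32 tensor):
//
//   static void scale_rows(struct ggml_tensor * dst, const struct ggml_tensor * a,
//                          int ith, int nth, void * userdata) {
//       const float   s  = *(const float *) userdata;
//       const int64_t nr = ggml_nrows(a);
//       for (int64_t ir = ith; ir < nr; ir += nth) { // rows are split across the worker threads
//           const float * x = (const float *) ((const char *) a->data   + ir*a->nb[1]);
//           float       * y = (float       *) ((char       *) dst->data + ir*dst->nb[1]);
//           for (int64_t i = 0; i < a->ne[0]; i++) {
//               y[i] = s*x[i];
//           }
//       }
//   }
//
//   // struct ggml_tensor * scaled = ggml_map_custom1(ctx, x, scale_rows, GGML_N_TASKS_MAX, &scale);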
  4757. // ggml_map_custom2
  4758. static struct ggml_tensor * ggml_map_custom2_impl(
  4759. struct ggml_context * ctx,
  4760. struct ggml_tensor * a,
  4761. struct ggml_tensor * b,
  4762. const ggml_custom2_op_t fun,
  4763. int n_tasks,
  4764. void * userdata,
  4765. bool inplace) {
  4766. GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
  4767. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  4768. struct ggml_map_custom2_op_params params = {
  4769. /*.fun =*/ fun,
  4770. /*.n_tasks =*/ n_tasks,
  4771. /*.userdata =*/ userdata
  4772. };
  4773. ggml_set_op_params(result, &params, sizeof(params));
  4774. result->op = GGML_OP_MAP_CUSTOM2;
  4775. result->src[0] = a;
  4776. result->src[1] = b;
  4777. return result;
  4778. }
  4779. struct ggml_tensor * ggml_map_custom2(
  4780. struct ggml_context * ctx,
  4781. struct ggml_tensor * a,
  4782. struct ggml_tensor * b,
  4783. const ggml_custom2_op_t fun,
  4784. int n_tasks,
  4785. void * userdata) {
  4786. return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false);
  4787. }
  4788. struct ggml_tensor * ggml_map_custom2_inplace(
  4789. struct ggml_context * ctx,
  4790. struct ggml_tensor * a,
  4791. struct ggml_tensor * b,
  4792. const ggml_custom2_op_t fun,
  4793. int n_tasks,
  4794. void * userdata) {
  4795. return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true);
  4796. }
  4797. // ggml_map_custom3
  4798. static struct ggml_tensor * ggml_map_custom3_impl(
  4799. struct ggml_context * ctx,
  4800. struct ggml_tensor * a,
  4801. struct ggml_tensor * b,
  4802. struct ggml_tensor * c,
  4803. const ggml_custom3_op_t fun,
  4804. int n_tasks,
  4805. void * userdata,
  4806. bool inplace) {
  4807. GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
  4808. struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
  4809. struct ggml_map_custom3_op_params params = {
  4810. /*.fun =*/ fun,
  4811. /*.n_tasks =*/ n_tasks,
  4812. /*.userdata =*/ userdata
  4813. };
  4814. ggml_set_op_params(result, &params, sizeof(params));
  4815. result->op = GGML_OP_MAP_CUSTOM3;
  4816. result->src[0] = a;
  4817. result->src[1] = b;
  4818. result->src[2] = c;
  4819. return result;
  4820. }
  4821. struct ggml_tensor * ggml_map_custom3(
  4822. struct ggml_context * ctx,
  4823. struct ggml_tensor * a,
  4824. struct ggml_tensor * b,
  4825. struct ggml_tensor * c,
  4826. const ggml_custom3_op_t fun,
  4827. int n_tasks,
  4828. void * userdata) {
  4829. return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false);
  4830. }
  4831. struct ggml_tensor * ggml_map_custom3_inplace(
  4832. struct ggml_context * ctx,
  4833. struct ggml_tensor * a,
  4834. struct ggml_tensor * b,
  4835. struct ggml_tensor * c,
  4836. const ggml_custom3_op_t fun,
  4837. int n_tasks,
  4838. void * userdata) {
  4839. return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
  4840. }
  4841. struct ggml_tensor * ggml_custom_4d(
  4842. struct ggml_context * ctx,
  4843. enum ggml_type type,
  4844. int64_t ne0,
  4845. int64_t ne1,
  4846. int64_t ne2,
  4847. int64_t ne3,
  4848. struct ggml_tensor ** args,
  4849. int n_args,
  4850. ggml_custom_op_t fun,
  4851. int n_tasks,
  4852. void * userdata) {
  4853. GGML_ASSERT(n_args < GGML_MAX_SRC);
  4854. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
  4855. struct ggml_custom_op_params params = {
  4856. /*.fun =*/ fun,
  4857. /*.n_tasks =*/ n_tasks,
  4858. /*.userdata =*/ userdata
  4859. };
  4860. ggml_set_op_params(result, &params, sizeof(params));
  4861. result->op = GGML_OP_CUSTOM;
  4862. for (int i = 0; i < n_args; i++) {
  4863. result->src[i] = args[i];
  4864. }
  4865. return result;
  4866. }
  4867. struct ggml_tensor * ggml_custom_inplace(
  4868. struct ggml_context * ctx,
  4869. struct ggml_tensor * a,
  4870. struct ggml_tensor ** args,
  4871. int n_args,
  4872. ggml_custom_op_t fun,
  4873. int n_tasks,
  4874. void * userdata) {
  4875. GGML_ASSERT(n_args < GGML_MAX_SRC - 1);
  4876. struct ggml_tensor * result = ggml_view_tensor(ctx, a);
  4877. struct ggml_custom_op_params params = {
  4878. /*.fun =*/ fun,
  4879. /*.n_tasks =*/ n_tasks,
  4880. /*.userdata =*/ userdata
  4881. };
  4882. ggml_set_op_params(result, &params, sizeof(params));
  4883. result->op = GGML_OP_CUSTOM;
  4884. result->src[0] = a;
  4885. for (int i = 0; i < n_args; i++) {
  4886. result->src[i + 1] = args[i];
  4887. }
  4888. return result;
  4889. }
  4890. // ggml_cross_entropy_loss
  4891. struct ggml_tensor * ggml_cross_entropy_loss(
  4892. struct ggml_context * ctx,
  4893. struct ggml_tensor * a,
  4894. struct ggml_tensor * b) {
  4895. GGML_ASSERT(ggml_are_same_shape(a, b));
  4896. struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
  4897. result->op = GGML_OP_CROSS_ENTROPY_LOSS;
  4898. result->src[0] = a;
  4899. result->src[1] = b;
  4900. return result;
  4901. }
  4902. // ggml_cross_entropy_loss_back
  4903. struct ggml_tensor * ggml_cross_entropy_loss_back(
  4904. struct ggml_context * ctx,
  4905. struct ggml_tensor * a,
  4906. struct ggml_tensor * b,
  4907. struct ggml_tensor * c) {
  4908. GGML_ASSERT(ggml_is_scalar(a));
  4909. GGML_ASSERT(ggml_are_same_shape(b, c));
  4910. struct ggml_tensor * result = ggml_dup_tensor(ctx, b);
  4911. result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
  4912. result->src[0] = a;
  4913. result->src[1] = b;
  4914. result->src[2] = c;
  4915. return result;
  4916. }
  4917. // opt_step_adamw
  4918. struct ggml_tensor * ggml_opt_step_adamw(
  4919. struct ggml_context * ctx,
  4920. struct ggml_tensor * a,
  4921. struct ggml_tensor * grad,
  4922. struct ggml_tensor * m,
  4923. struct ggml_tensor * v,
  4924. struct ggml_tensor * adamw_params) {
  4925. GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
  4926. GGML_ASSERT(ggml_are_same_shape(a, grad));
  4927. GGML_ASSERT(ggml_are_same_shape(a, m));
  4928. GGML_ASSERT(ggml_are_same_shape(a, v));
  4929. GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);
  4930. GGML_ASSERT(ggml_nelements(adamw_params) == 7);
  4931. struct ggml_tensor * result = ggml_view_tensor(ctx, a);
  4932. result->op = GGML_OP_OPT_STEP_ADAMW;
  4933. result->src[0] = a;
  4934. result->src[1] = grad;
  4935. result->src[2] = m;
  4936. result->src[3] = v;
  4937. result->src[4] = adamw_params;
  4938. return result;
  4939. }
  4940. // opt_step_sgd
  4941. struct ggml_tensor * ggml_opt_step_sgd(
  4942. struct ggml_context * ctx,
  4943. struct ggml_tensor * a,
  4944. struct ggml_tensor * grad,
  4945. struct ggml_tensor * params) {
  4946. GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
  4947. GGML_ASSERT(ggml_are_same_shape(a, grad));
  4948. GGML_ASSERT(params->type == GGML_TYPE_F32);
  4949. GGML_ASSERT(ggml_nelements(params) == 2);
  4950. struct ggml_tensor * result = ggml_view_tensor(ctx, a);
  4951. result->op = GGML_OP_OPT_STEP_SGD;
  4952. result->src[0] = a;
  4953. result->src[1] = grad;
  4954. result->src[2] = params;
  4955. return result;
  4956. }
  4957. // solve_tri
  4958. struct ggml_tensor * ggml_solve_tri(
  4959. struct ggml_context * ctx,
  4960. struct ggml_tensor * a,
  4961. struct ggml_tensor * b,
  4962. bool left,
  4963. bool lower,
  4964. bool uni) {
  4965. GGML_ASSERT(a->type == GGML_TYPE_F32);
  4966. GGML_ASSERT(b->type == GGML_TYPE_F32);
4967. // A must be square and lower triangular
  4968. GGML_ASSERT(a->ne[0] == a->ne[1]);
  4969. // B must have same outer dimension as A
  4970. GGML_ASSERT(a->ne[1] == b->ne[1]);
  4971. // batch dimensions must be equal
  4972. GGML_ASSERT(a->ne[2] == b->ne[2]);
  4973. GGML_ASSERT(a->ne[3] == b->ne[3]);
  4974. GGML_ASSERT(ggml_is_contiguous(a));
  4975. GGML_ASSERT(ggml_is_contiguous(b));
  4976. GGML_ASSERT(lower && left && !uni); // TODO: support other variants
  4977. struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, b->ne[0], b->ne[1], b->ne[2], b->ne[3]);
  4978. result->op = GGML_OP_SOLVE_TRI;
  4979. result->src[0] = a;
  4980. result->src[1] = b;
  4981. return result;
  4982. }
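// With the only combination currently allowed by the assert above (left, lower, non-unit diagonal),
// this presumably solves A * X = B for X by forward substitution: A is an n x n lower-triangular
// matrix, and B and the F32 result share the shape [b->ne[0], n], batched over dims 2 and 3.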
  4983. ////////////////////////////////////////////////////////////////////////////////
  4984. struct ggml_hash_set ggml_hash_set_new(size_t size) {
  4985. size = ggml_hash_size(size);
  4986. struct ggml_hash_set result;
  4987. result.size = size;
  4988. result.keys = GGML_MALLOC(sizeof(struct ggml_tensor *) * size);
  4989. result.used = GGML_CALLOC(ggml_bitset_size(size), sizeof(ggml_bitset_t));
  4990. return result;
  4991. }
  4992. void ggml_hash_set_reset(struct ggml_hash_set * hash_set) {
  4993. memset(hash_set->used, 0, sizeof(ggml_bitset_t) * ggml_bitset_size(hash_set->size));
  4994. }
  4995. void ggml_hash_set_free(struct ggml_hash_set * hash_set) {
  4996. GGML_FREE(hash_set->used);
  4997. GGML_FREE(hash_set->keys);
  4998. }
  4999. size_t ggml_hash_size(size_t min_sz) {
  5000. // next primes after powers of two
  5001. static const size_t primes[] = {
  5002. 2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031,
  5003. 2053, 4099, 8209, 16411, 32771, 65537, 131101,
  5004. 262147, 524309, 1048583, 2097169, 4194319, 8388617,
  5005. 16777259, 33554467, 67108879, 134217757, 268435459,
  5006. 536870923, 1073741827, 2147483659
  5007. };
  5008. static const size_t n_primes = sizeof(primes)/sizeof(primes[0]);
5009. // find the smallest prime that is larger than or equal to min_sz
  5010. size_t l = 0;
  5011. size_t r = n_primes;
  5012. while (l < r) {
  5013. size_t m = (l + r)/2;
  5014. if (primes[m] < min_sz) {
  5015. l = m + 1;
  5016. } else {
  5017. r = m;
  5018. }
  5019. }
  5020. size_t sz = l < n_primes ? primes[l] : min_sz | 1;
  5021. return sz;
  5022. }
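// e.g. ggml_hash_size(1000) returns 1031 (the first prime in the table that is >= 1000); requests
// beyond the last table entry fall back to min_sz | 1 (forced odd) rather than a prime.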
  5023. struct hash_map {
  5024. struct ggml_hash_set set;
  5025. struct ggml_tensor ** vals;
  5026. };
  5027. static struct hash_map * ggml_new_hash_map(size_t size) {
  5028. struct hash_map * result = GGML_MALLOC(sizeof(struct hash_map));
  5029. result->set = ggml_hash_set_new(size);
  5030. result->vals = GGML_CALLOC(result->set.size, sizeof(struct ggml_tensor *));
  5031. return result;
  5032. }
  5033. static void ggml_hash_map_free(struct hash_map * map) {
  5034. ggml_hash_set_free(&map->set);
  5035. GGML_FREE(map->vals);
  5036. GGML_FREE(map);
  5037. }
  5038. // utility functions to change gradients
5039. // isrc is the index of tensor in cgraph->visited_hash_set.keys
5040. // the corresponding gradients (accumulators) are also at position isrc
  5041. // if tensor has a gradient accumulator, modify that accumulator in-place
  5042. // else if there is no gradient for tensor, set the corresponding value
  5043. // else, just add/subtract/etc. the gradients
  5044. static void ggml_add_or_set(
  5045. struct ggml_context * ctx,
  5046. struct ggml_cgraph * cgraph,
  5047. size_t isrc,
  5048. struct ggml_tensor * tensor) {
  5049. struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
  5050. GGML_ASSERT(src);
  5051. if (cgraph->grads[isrc]) {
  5052. cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, /*inplace =*/ cgraph->grad_accs[isrc]);
  5053. } else {
  5054. cgraph->grads[isrc] = tensor;
  5055. }
  5056. ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
  5057. ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
  5058. }
  5059. static void ggml_acc_or_set(
  5060. struct ggml_context * ctx,
  5061. struct ggml_cgraph * cgraph,
  5062. size_t isrc,
  5063. struct ggml_tensor * tensor,
  5064. const size_t nb1,
  5065. const size_t nb2,
  5066. const size_t nb3,
  5067. const size_t offset) {
  5068. struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
  5069. GGML_ASSERT(src);
  5070. if (cgraph->grads[isrc]) {
  5071. cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]);
  5072. } else {
  5073. struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
  5074. cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false);
  5075. }
  5076. ggml_format_name(cgraph->grads[isrc], "grad for %s", cgraph->visited_hash_set.keys[isrc]->name);
  5077. ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
  5078. }
  5079. static void ggml_add1_or_set(
  5080. struct ggml_context * ctx,
  5081. struct ggml_cgraph * cgraph,
  5082. size_t isrc,
  5083. struct ggml_tensor * tensor) {
  5084. struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
  5085. GGML_ASSERT(src);
  5086. if (cgraph->grads[isrc]) {
  5087. cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
  5088. } else {
  5089. cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src);
  5090. }
  5091. ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
  5092. ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
  5093. }
  5094. static void ggml_sub_or_set(
  5095. struct ggml_context * ctx,
  5096. struct ggml_cgraph * cgraph,
  5097. size_t isrc,
  5098. struct ggml_tensor * tensor) {
  5099. struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
  5100. GGML_ASSERT(src);
  5101. if (cgraph->grads[isrc]) {
  5102. cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
  5103. } else {
  5104. cgraph->grads[isrc] = ggml_neg(ctx, tensor);
  5105. }
  5106. ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
  5107. ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
  5108. }
  5109. static void ggml_compute_backward(
  5110. struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) {
  5111. struct ggml_tensor * tensor = cgraph->nodes[i];
  5112. struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor);
  5113. if (!grad) {
  5114. return;
  5115. }
  5116. struct ggml_tensor * src0 = tensor->src[0];
  5117. struct ggml_tensor * src1 = tensor->src[1];
  5118. struct ggml_tensor * src2 = tensor->src[2];
  5119. struct ggml_hash_set * hash_set = &cgraph->visited_hash_set;
  5120. const size_t isrc0 = src0 ? ggml_hash_find(hash_set, src0) : (size_t) -1;
  5121. const size_t isrc1 = src1 ? ggml_hash_find(hash_set, src1) : (size_t) -1;
  5122. const size_t isrc2 = src2 ? ggml_hash_find(hash_set, src2) : (size_t) -1;
  5123. const bool src0_needs_grads = src0 && isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
  5124. const bool src1_needs_grads = src1 && isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
  5125. const bool src2_needs_grads = src2 && isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
  5126. switch (tensor->op) {
  5127. case GGML_OP_DUP: {
  5128. if (src0_needs_grads) {
  5129. ggml_add_or_set(ctx, cgraph, isrc0, grad);
  5130. }
  5131. } break;
  5132. case GGML_OP_ADD: {
  5133. if (src0_needs_grads) {
  5134. ggml_add_or_set(ctx, cgraph, isrc0, grad);
  5135. }
  5136. if (src1_needs_grads) {
  5137. struct ggml_tensor * tmp = grad;
  5138. if (!ggml_are_same_shape(src0, src1)) {
  5139. tmp = ggml_repeat_back(ctx, tmp, src1);
  5140. }
  5141. ggml_add_or_set(ctx, cgraph, isrc1, tmp);
  5142. }
  5143. } break;
  5144. case GGML_OP_ADD1: {
  5145. if (src0_needs_grads) {
  5146. ggml_add_or_set(ctx, cgraph, isrc0, grad);
  5147. }
  5148. if (src1_needs_grads) {
  5149. ggml_add_or_set(ctx, cgraph, isrc1, ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean
  5150. }
  5151. } break;
  5152. case GGML_OP_ACC: {
  5153. if (src0_needs_grads) {
  5154. ggml_add_or_set(ctx, cgraph, isrc0, grad);
  5155. }
  5156. if (src1_needs_grads) {
  5157. const size_t nb1 = ((int32_t *) tensor->op_params)[0];
  5158. const size_t nb2 = ((int32_t *) tensor->op_params)[1];
  5159. const size_t nb3 = ((int32_t *) tensor->op_params)[2];
  5160. const size_t offset = ((int32_t *) tensor->op_params)[3];
  5161. struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
  5162. grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
  5163. nb1, nb2, nb3, offset);
  5164. ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));
  5165. }
  5166. } break;
  5167. case GGML_OP_SUB: {
  5168. if (src0_needs_grads) {
  5169. ggml_add_or_set(ctx, cgraph, isrc0, grad);
  5170. }
  5171. if (src1_needs_grads) {
  5172. ggml_sub_or_set(ctx, cgraph, isrc1, grad);
  5173. }
  5174. } break;
  5175. case GGML_OP_MUL: {
  5176. if (src0_needs_grads) {
  5177. ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
  5178. }
  5179. if (src1_needs_grads) {
  5180. struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
  5181. if (!ggml_are_same_shape(src0, src1)) {
  5182. tmp = ggml_repeat_back(ctx, tmp, src1);
  5183. }
  5184. ggml_add_or_set(ctx, cgraph, isrc1, tmp);
  5185. }
  5186. } break;
  5187. case GGML_OP_DIV: {
  5188. if (src0_needs_grads) {
  5189. ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src1));
  5190. }
  5191. if (src1_needs_grads) {
  5192. ggml_sub_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, grad, ggml_div(ctx, tensor, src1)));
  5193. }
  5194. } break;
  5195. case GGML_OP_SQR: {
  5196. if (src0_needs_grads) {
  5197. ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_mul(ctx, src0, grad), 2.0f));
  5198. }
  5199. } break;
  5200. case GGML_OP_SQRT: {
  5201. if (src0_needs_grads) {
  5202. ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_div(ctx, grad, tensor), 0.5f));
  5203. }
  5204. } break;
  5205. case GGML_OP_LOG: {
  5206. if (src0_needs_grads) {
  5207. ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src0));
  5208. }
  5209. } break;
  5210. case GGML_OP_SIN: {
  5211. if (src0_needs_grads) {
  5212. ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_cos(ctx, src0)));
  5213. }
  5214. } break;
  5215. case GGML_OP_COS: {
  5216. if (src0_needs_grads) {
  5217. ggml_sub_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sin(ctx, src0)));
  5218. }
  5219. } break;
  5220. case GGML_OP_SUM: {
  5221. if (src0_needs_grads) {
  5222. ggml_add1_or_set(ctx, cgraph, isrc0, grad);
  5223. }
  5224. } break;
  5225. case GGML_OP_SUM_ROWS: {
  5226. if (src0_needs_grads) {
  5227. ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));
  5228. }
  5229. } break;
  5230. case GGML_OP_MEAN: {
  5231. if (src0_needs_grads) {
  5232. ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));
  5233. }
  5234. } break;
  5235. case GGML_OP_REPEAT: {
  5236. if (src0_needs_grads) {
  5237. ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0));
  5238. }
  5239. } break;
  5240. case GGML_OP_REPEAT_BACK: {
  5241. if (src0_needs_grads) {
  5242. ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));
  5243. }
  5244. } break;
  5245. case GGML_OP_RMS_NORM: {
  5246. if (src0_needs_grads) {
  5247. float eps;
  5248. memcpy(&eps, tensor->op_params, sizeof(float));
  5249. ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));
  5250. }
  5251. } break;
  5252. case GGML_OP_MUL_MAT: {
  5253. // https://cs231n.github.io/optimization-2/#staged
  5254. // # forward pass
  5255. // s0 = np.random.randn(5, 10)
  5256. // s1 = np.random.randn(10, 3)
  5257. // t = s0.dot(s1)
  5258. // # now suppose we had the gradient on t from above in the circuit
  5259. // dt = np.random.randn(*t.shape) # same shape as t
  5260. // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
  5261. // ds1 = t.T.dot(dt)
  5262. // tensor.shape [m,p,qq,rr]
  5263. // src0.shape [n,m,q1,r1]
  5264. // src1.shape [n,p,qq,rr]
  5265. if (src0_needs_grads) {
  5266. GGML_ASSERT(grad->ne[2] == src1->ne[2]);
  5267. GGML_ASSERT(grad->ne[3] == src1->ne[3]);
  5268. struct ggml_tensor * tmp =
  5269. ggml_out_prod(ctx, // [n,m,qq,rr]
  5270. src1, // [n,p,qq,rr]
  5271. grad); // [m,p,qq,rr]
  5272. if (!ggml_are_same_shape(tmp, src0)) {
  5273. GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
  5274. GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
  5275. GGML_ASSERT(tmp->ne[3] == 1);
  5276. const int64_t nr2 = tmp->ne[2] / src0->ne[2];
  5277. const size_t nb2 = tmp->nb[2] * nr2;
  5278. const size_t nb3 = tmp->nb[2];
  5279. tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
  5280. tmp = ggml_repeat_back(ctx, tmp, src0);
  5281. }
  5282. ggml_add_or_set(ctx, cgraph, isrc0, tmp);
  5283. }
  5284. if (src1_needs_grads) {
  5285. ggml_add_or_set(ctx, cgraph, isrc1,
  5286. // ggml_mul_mat(ctx, // [n,p,qq,rr]
  5287. // ggml_cont(ctx, // [m,n,q1,r1]
  5288. // ggml_transpose(ctx, src0)), // [m,n,q1,r1]
  5289. // grad), // [m,p,qq,rr]
  5290. // when src0 is bigger than tensor->grad (this is mostly the case in llama),
5291. // avoid transposing src0; instead transpose the smaller tensor->grad
                    // and then use ggml_out_prod
                    ggml_out_prod(ctx, // [n,p,qq,rr]
                        src0, // [n,m,q1,r1]
                        ggml_transpose(ctx, // [p,m,qq,rr]
                            grad))); // [m,p,qq,rr]
            }
        } break;
        case GGML_OP_SCALE: {
            if (src0_needs_grads) {
                float s;
                memcpy(&s, tensor->op_params, sizeof(float));
                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false));
            }
        } break;
        case GGML_OP_SET: {
            const size_t nb1    = ((const int32_t *) tensor->op_params)[0];
            const size_t nb2    = ((const int32_t *) tensor->op_params)[1];
            const size_t nb3    = ((const int32_t *) tensor->op_params)[2];
            const size_t offset = ((const int32_t *) tensor->op_params)[3];

            struct ggml_tensor * tensor_grad_view = NULL;

            if (src0_needs_grads || src1_needs_grads) {
                GGML_ASSERT(src0->type == tensor->type);
                GGML_ASSERT(!cgraph->grads[isrc0] || cgraph->grads[isrc0]->type == grad->type);
                GGML_ASSERT(!cgraph->grads[isrc1] || !src1_needs_grads || cgraph->grads[isrc1]->type == grad->type);

                tensor_grad_view = ggml_view_4d(ctx,
                    grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
                    nb1, nb2, nb3, offset);
            }

            if (src0_needs_grads) {
                struct ggml_tensor * tmp = ggml_neg(ctx, tensor_grad_view);
                ggml_add_or_set(ctx, cgraph, isrc0, ggml_acc_impl(ctx, grad, tmp, nb1, nb2, nb3, offset, false));
            }

            if (src1_needs_grads) {
                ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));
            }
        } break;
        case GGML_OP_CPY: {
            // cpy overwrites value of src1 by src0 and returns view(src1)
            // the overwriting is mathematically equivalent to:
            // tensor = src0 * 1 + src1 * 0
            if (src0_needs_grads) {
                // dsrc0 = dtensor * 1
                ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0));
            }
            if (src1_needs_grads) {
                // dsrc1 = dtensor * 0 -> noop
            }
        } break;
        case GGML_OP_CONT: {
            // same as cpy
            if (src0_needs_grads) {
                GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
                GGML_ASSERT(ggml_is_contiguous(grad));
                GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
                ggml_add_or_set(ctx, cgraph, isrc0,
                    ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
            }
        } break;
        case GGML_OP_RESHAPE: {
            if (src0_needs_grads) {
                struct ggml_tensor * grad_cont = ggml_is_contiguous(grad) ? grad : ggml_cont(ctx, grad);
                ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad_cont, src0));
            }
        } break;
        case GGML_OP_VIEW: {
            if (src0_needs_grads) {
                size_t offset;
                memcpy(&offset, tensor->op_params, sizeof(offset));

                size_t nb1 = tensor->nb[1];
                size_t nb2 = tensor->nb[2];
                size_t nb3 = tensor->nb[3];

                if (cgraph->grads[isrc0] && src0->type != cgraph->grads[isrc0]->type) {
                    // gradient is typically F32, but src0 could be other type
                    size_t ng = ggml_element_size(cgraph->grads[isrc0]);
                    size_t n0 = ggml_element_size(src0);

                    GGML_ASSERT(offset % n0 == 0);
                    GGML_ASSERT(nb1 % n0 == 0);
                    GGML_ASSERT(nb2 % n0 == 0);
                    GGML_ASSERT(nb3 % n0 == 0);

                    offset = (offset / n0) * ng;
                    nb1 = (nb1 / n0) * ng;
                    nb2 = (nb2 / n0) * ng;
                    nb3 = (nb3 / n0) * ng;
                }

                ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset);
            }
        } break;
        case GGML_OP_PERMUTE: {
            if (src0_needs_grads) {
                const int32_t * axes = (const int32_t *) tensor->op_params;
                const int axis0 = axes[0] & 0x3;
                const int axis1 = axes[1] & 0x3;
                const int axis2 = axes[2] & 0x3;
                const int axis3 = axes[3] & 0x3;
                int axb[4] = {0,0,0,0}; // axes backward
                axb[axis0] = 0;
                axb[axis1] = 1;
                axb[axis2] = 2;
                axb[axis3] = 3;
                ggml_add_or_set(ctx, cgraph, isrc0, ggml_permute(ctx, grad, axb[0], axb[1], axb[2], axb[3]));
            }
        } break;
        case GGML_OP_TRANSPOSE: {
            if (src0_needs_grads) {
                ggml_add_or_set(ctx, cgraph, isrc0, ggml_transpose(ctx, grad));
            }
        } break;
        case GGML_OP_GET_ROWS: {
            if (src0_needs_grads) {
                ggml_add_or_set(ctx, cgraph, isrc0, ggml_get_rows_back(ctx, grad, src1, src0));
            }
            if (src1_needs_grads) {
                // noop
            }
        } break;
        case GGML_OP_DIAG_MASK_INF: {
            if (src0_needs_grads) {
                /* ggml_diag_mask_inf_impl() shouldn't be here */
                /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
                const int n_past = ((const int32_t *) tensor->op_params)[0];
                ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
            }
        } break;
        case GGML_OP_DIAG_MASK_ZERO: {
            if (src0_needs_grads) {
                const int n_past = ((const int32_t *) tensor->op_params)[0];
                ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
            }
        } break;
        case GGML_OP_SOFT_MAX: {
            if (src0_needs_grads) {
                float scale = 1.0f;
                float max_bias = 0.0f;
                memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float));
                memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));
                ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));
            }
            GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
        } break;
        case GGML_OP_ROPE: {
            if (src0_needs_grads) {
                //const int n_past = ((int32_t *) tensor->op_params)[0];
                const int n_dims     = ((const int32_t *) tensor->op_params)[1];
                const int mode       = ((const int32_t *) tensor->op_params)[2];
                //const int n_ctx      = ((int32_t *) tensor->op_params)[3];
                const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];

                float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                int sections[4] = {0, 0, 0, 0};

                memcpy(&freq_base,   (const float *) tensor->op_params +  5, sizeof(float));
                memcpy(&freq_scale,  (const float *) tensor->op_params +  6, sizeof(float));
                memcpy(&ext_factor,  (const float *) tensor->op_params +  7, sizeof(float));
                memcpy(&attn_factor, (const float *) tensor->op_params +  8, sizeof(float));
                memcpy(&beta_fast,   (const float *) tensor->op_params +  9, sizeof(float));
                memcpy(&beta_slow,   (const float *) tensor->op_params + 10, sizeof(float));
                memcpy(&sections, tensor->op_params + 11, sizeof(sections));

                struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ?
                    ggml_rope_ext_back(ctx, grad, src1, src2, n_dims,
                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) :
                    ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections,
                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
                ggml_add_or_set(ctx, cgraph, isrc0, rope_back);
            }
            GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
        } break;
        case GGML_OP_IM2COL: {
            if (src1_needs_grads) {
                const int32_t s0    = ggml_get_op_params_i32(tensor, 0);
                const int32_t s1    = ggml_get_op_params_i32(tensor, 1);
                const int32_t p0    = ggml_get_op_params_i32(tensor, 2);
                const int32_t p1    = ggml_get_op_params_i32(tensor, 3);
                const int32_t d0    = ggml_get_op_params_i32(tensor, 4);
                const int32_t d1    = ggml_get_op_params_i32(tensor, 5);
                const bool    is_2D = ggml_get_op_params_i32(tensor, 6) == 1;

                ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
            }
        } break;
        case GGML_OP_POOL_2D: {
            if (src0_needs_grads) {
                const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0);
                const int32_t k0 = ggml_get_op_params_i32(tensor, 1);
                const int32_t k1 = ggml_get_op_params_i32(tensor, 2);
                const int32_t s0 = ggml_get_op_params_i32(tensor, 3);
                const int32_t s1 = ggml_get_op_params_i32(tensor, 4);
                const int32_t p0 = ggml_get_op_params_i32(tensor, 5);
                const int32_t p1 = ggml_get_op_params_i32(tensor, 6);

                ggml_add_or_set(ctx, cgraph, isrc0, ggml_pool_2d_back(ctx, grad, src0, op, k0, k1, s0, s1, p0, p1));
            }
        } break;
        case GGML_OP_WIN_PART:
        case GGML_OP_WIN_UNPART:
        case GGML_OP_UNARY: {
            switch (ggml_get_unary_op(tensor)) {
                case GGML_UNARY_OP_ABS: {
                    if (src0_needs_grads) {
                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_sgn(ctx, src0), grad));
                    }
                } break;
                case GGML_UNARY_OP_SGN: {
                    // noop
                } break;
                case GGML_UNARY_OP_NEG: {
                    if (src0_needs_grads) {
                        ggml_sub_or_set(ctx, cgraph, isrc0, grad);
                    }
                } break;
                case GGML_UNARY_OP_STEP: {
                    // noop
                } break;
                case GGML_UNARY_OP_RELU: {
                    if (src0_needs_grads) {
                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_step(ctx, src0), grad));
                    }
                } break;
                case GGML_UNARY_OP_SILU: {
                    if (src0_needs_grads) {
                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));
                    }
                } break;
                case GGML_UNARY_OP_EXP: {
                    if (src0_needs_grads) {
                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad));
                    }
                } break;
                case GGML_UNARY_OP_EXPM1: {
                    if (src0_needs_grads) {
                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_exp(ctx, src0)));
                    }
                } break;
                case GGML_UNARY_OP_SOFTPLUS: {
                    if (src0_needs_grads) {
                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sigmoid(ctx, src0)));
                    }
                } break;
                default: {
                    fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n",
                        __func__, ggml_unary_op_name(ggml_get_unary_op(tensor)));
                    GGML_ABORT("fatal error");
                } //break;
            }
        } break;
        case GGML_OP_CROSS_ENTROPY_LOSS: {
            if (src0_needs_grads) {
                ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1));
            }
            GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
        } break;
        case GGML_OP_GLU: {
            switch (ggml_get_glu_op(tensor)) {
                case GGML_GLU_OP_SWIGLU: {
                    if (src0_needs_grads) {
                        GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
                    }
                    if (src1_needs_grads) {
                        ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
                    }
                } break;
                default: {
                    GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
                } //break;
            }
        } break;
        case GGML_OP_NONE: {
            // noop
        } break;
        case GGML_OP_COUNT:
        default: {
            GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
        } //break;
    }

    GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(src0, cgraph->grads[isrc0]));
    GGML_ASSERT(!src1_needs_grads || ggml_are_same_shape(src1, cgraph->grads[isrc1]));
    GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
}
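
// Illustrative note (not part of the original ggml sources): each case above is one application of
// the chain rule for a single op. For example, GGML_OP_SCALE computes y = s*x, so dL/dx = s * dL/dy,
// which is exactly the ggml_scale_impl(ctx, grad, s, 0.0, false) accumulated into grads[isrc0];
// GGML_UNARY_OP_RELU uses step(x) as the derivative of relu(x), hence ggml_mul(ctx, ggml_step(ctx, src0), grad).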

static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
    // check if already visited
    size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
    GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
    if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
        // This is the first time we see this node in the current graph.
        cgraph->visited_hash_set.keys[node_hash_pos] = node;
        ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
        cgraph->use_counts[node_hash_pos] = 0;
    } else {
        // already visited
        return node_hash_pos;
    }

    for (int i = 0; i < GGML_MAX_SRC; ++i) {
        const int k =
            (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
            (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
            /* unknown order, just fall back to using i */ i;

        struct ggml_tensor * src = node->src[k];
        if (src) {
            size_t src_hash_pos = ggml_visit_parents(cgraph, src);

            // Update the use count for this operand.
            cgraph->use_counts[src_hash_pos]++;
        }
    }

    if (node->op == GGML_OP_NONE && !(node->flags & GGML_TENSOR_FLAG_PARAM)) {
        // reached a leaf node, not part of the gradient graph (e.g. a constant)
        GGML_ASSERT(cgraph->n_leafs < cgraph->size);

        if (strlen(node->name) == 0) {
            ggml_format_name(node, "leaf_%d", cgraph->n_leafs);
        }

        cgraph->leafs[cgraph->n_leafs] = node;
        cgraph->n_leafs++;
    } else {
        GGML_ASSERT(cgraph->n_nodes < cgraph->size);

        if (strlen(node->name) == 0) {
            ggml_format_name(node, "node_%d", cgraph->n_nodes);
        }

        cgraph->nodes[cgraph->n_nodes] = node;
        cgraph->n_nodes++;
    }

    return node_hash_pos;
}

static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
    if (!expand) {
        // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand
        ggml_graph_clear(cgraph);
    }

    const int n0 = cgraph->n_nodes;

    ggml_visit_parents(cgraph, tensor);

    const int n_new = cgraph->n_nodes - n0;
    GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);

    if (n_new > 0) {
        // the last added node should always be starting point
        GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor);
    }
}

void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
    ggml_build_forward_impl(cgraph, tensor, true);
}
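
// Illustrative usage sketch (not part of the original ggml sources; the names ctx, a, b, out, gf are
// hypothetical). ggml_build_forward_expand() topologically sorts everything reachable from the given
// tensor into the graph, so a minimal forward graph could be built roughly like this:
//
//     struct ggml_init_params params = {
//         /*.mem_size   =*/ ggml_tensor_overhead()*64 + ggml_graph_overhead() + 1024*1024,
//         /*.mem_buffer =*/ NULL,
//         /*.no_alloc   =*/ false,
//     };
//     struct ggml_context * ctx = ggml_init(params);
//
//     struct ggml_tensor * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
//     struct ggml_tensor * b   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
//     struct ggml_tensor * out = ggml_mul_mat(ctx, a, b);
//
//     struct ggml_cgraph * gf = ggml_new_graph(ctx);
//     ggml_build_forward_expand(gf, out); // adds a, b and out (plus any intermediates) to gf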

void ggml_build_backward_expand(
        struct ggml_context *  ctx,
        struct ggml_cgraph  *  cgraph,
        struct ggml_tensor  ** grad_accs) {
    GGML_ASSERT(cgraph->n_nodes > 0);
    GGML_ASSERT(cgraph->grads);
    GGML_ASSERT(cgraph->grad_accs);

    const int n_nodes_f = cgraph->n_nodes;

    memset(cgraph->grads,     0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
    memset(cgraph->grad_accs, 0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
    bool * grads_needed = calloc(cgraph->visited_hash_set.size, sizeof(bool));

    {
        bool any_params = false;
        bool any_loss   = false;
        for (int i = 0; i < n_nodes_f; ++i) {
            struct ggml_tensor * node = cgraph->nodes[i];
            any_params = any_params || (node->flags & GGML_TENSOR_FLAG_PARAM);
            any_loss   = any_loss   || (node->flags & GGML_TENSOR_FLAG_LOSS);
        }
        GGML_ASSERT(any_params && "no trainable parameters found, did you forget to call ggml_set_param?");
        GGML_ASSERT(any_loss && "no training loss found, did you forget to call ggml_set_loss?");
    }

    for (int i = 0; i < n_nodes_f; ++i) {
        struct ggml_tensor * node = cgraph->nodes[i];

        if (node->type == GGML_TYPE_I32) {
            continue;
        }

        bool node_needs_grad = (node->flags & GGML_TENSOR_FLAG_PARAM) || (node->flags & GGML_TENSOR_FLAG_LOSS);
        bool ignore_src[GGML_MAX_SRC] = {false};
        switch (node->op) {
            // gradients in node->src[0] for one reason or another have no effect on output gradients
            case GGML_OP_IM2COL:      // only used for its shape
            case GGML_OP_IM2COL_BACK: // same as IM2COL
                ignore_src[0] = true;
                break;
            case GGML_OP_UNARY: {
                const enum ggml_unary_op uop = ggml_get_unary_op(node);
                // SGN and STEP unary ops are piecewise constant
                if (uop == GGML_UNARY_OP_SGN || uop == GGML_UNARY_OP_STEP) {
                    ignore_src[0] = true;
                }
            } break;

            // gradients in node->src[1] for one reason or another have no effect on output gradients
            case GGML_OP_CPY:           // gradients in CPY target are irrelevant
            case GGML_OP_GET_ROWS:      // row indices not differentiable
            case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS
            case GGML_OP_ROPE:          // positions not differentiable
                ignore_src[1] = true;
                break;

            default:
                break;
        }
        for (int j = 0; j < GGML_MAX_SRC; ++j) {
            if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) {
                continue;
            }
            GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16);
            node_needs_grad = true;
            break;
        }
        if (!node_needs_grad) {
            continue;
        }

        // inplace operations are currently not supported
        GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
            node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);

        const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
        GGML_ASSERT(ihash != GGML_HASHSET_FULL);
        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
        if (grad_accs && grad_accs[i]) {
            cgraph->grad_accs[ihash] = grad_accs[i];
            cgraph->grads[ihash]     = cgraph->grad_accs[ihash];
        } else if (node->flags & GGML_TENSOR_FLAG_LOSS) {
            // loss tensors always need a gradient accumulator
            cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
            cgraph->grads[ihash]     = cgraph->grad_accs[ihash];
        }
        grads_needed[ihash] = true;
    }

    for (int i = n_nodes_f - 1; i >= 0; --i) {
        // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
        // use allocator to automatically make inplace operations
        ggml_compute_backward(ctx, cgraph, i, grads_needed);
    }

    free(grads_needed);
}
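
// Illustrative usage sketch (not part of the original ggml sources; all variable names are
// hypothetical). Gradients are only expanded for tensors flagged as parameters, and the graph must
// contain a scalar tensor flagged as loss, so a minimal training graph could look like this:
//
//     struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2);
//     struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
//     ggml_set_param(w);                                  // w receives a gradient
//
//     struct ggml_tensor * y    = ggml_mul_mat(ctx, w, x);
//     struct ggml_tensor * loss = ggml_sum(ctx, y);       // scalar F32
//     ggml_set_loss(loss);
//
//     struct ggml_cgraph * gb = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
//     ggml_build_forward_expand(gb, loss);
//     ggml_build_backward_expand(ctx, gb, /*grad_accs =*/ NULL);
//
//     struct ggml_tensor * w_grad = ggml_graph_get_grad(gb, w); // same shape as w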

static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
    void * ptr = *p;
    ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
    *p = (void *) ((char *) ptr + size);
    return ptr;
}

static size_t ggml_graph_nbytes(size_t size, bool grads) {
    size_t hash_size = ggml_hash_size(size * 2);
    void * p = 0;
    incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
    incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
    incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
    incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts
    incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
    if (grads) {
        incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
        incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs
    }
    incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));

    size_t nbytes = (size_t) p;
    return nbytes;
}

size_t ggml_graph_overhead_custom(size_t size, bool grads) {
    return GGML_OBJECT_SIZE + GGML_PAD(ggml_graph_nbytes(size, grads), GGML_MEM_ALIGN);
}

size_t ggml_graph_overhead(void) {
    return ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, false);
}
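
// Illustrative note (not part of the original ggml sources; graph_size is a hypothetical bound):
// the overhead returned here is what a caller budgets in its ggml_context when the graph object
// itself is stored there, e.g.:
//
//     const size_t graph_size = 2048; // upper bound on the number of graph nodes
//     struct ggml_init_params params = {
//         /*.mem_size   =*/ ggml_graph_overhead_custom(graph_size, /*grads =*/ false) +
//                           graph_size*ggml_tensor_overhead(),
//         /*.mem_buffer =*/ NULL,
//         /*.no_alloc   =*/ true, // tensor data lives elsewhere (e.g. in a backend buffer)
//     };
//     struct ggml_context * ctx = ggml_init(params);
//     struct ggml_cgraph  * gf  = ggml_new_graph_custom(ctx, graph_size, false);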

struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads) {
    const size_t obj_size = ggml_graph_nbytes(size, grads);
    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_GRAPH, obj_size);
    struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);

    // the size of the hash table is doubled since it needs to hold both nodes and leafs
    size_t hash_size = ggml_hash_size(size * 2);

    void * p = cgraph + 1;

    struct ggml_tensor ** nodes_ptr      = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
    struct ggml_tensor ** leafs_ptr      = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
    int32_t             * use_counts_ptr = incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
    struct ggml_tensor ** hash_keys_ptr  = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
    struct ggml_tensor ** grads_ptr      = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
    struct ggml_tensor ** grad_accs_ptr  = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;

    ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));

    // check that we allocated the correct amount of memory
    assert(obj_size == (size_t)((char *)p - (char *)cgraph));

    *cgraph = (struct ggml_cgraph) {
        /*.size       =*/ size,
        /*.n_nodes    =*/ 0,
        /*.n_leafs    =*/ 0,
        /*.nodes      =*/ nodes_ptr,
        /*.grads      =*/ grads_ptr,
        /*.grad_accs  =*/ grad_accs_ptr,
        /*.leafs      =*/ leafs_ptr,
        /*.use_counts =*/ use_counts_ptr,
        /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr },
        /*.order      =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
    };

    ggml_hash_set_reset(&cgraph->visited_hash_set);
    if (grads) {
        memset(cgraph->grads,     0, hash_size*sizeof(struct ggml_tensor *));
        memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
    }

    return cgraph;
}

struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
    return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false);
}

struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) {
    struct ggml_cgraph cgraph = {
        /*.size             =*/ 0,
        /*.n_nodes          =*/ i1 - i0,
        /*.n_leafs          =*/ 0,
        /*.nodes            =*/ cgraph0->nodes + i0,
        /*.grads            =*/ NULL, // gradients would need visited_hash_set
        /*.grad_accs        =*/ NULL,
        /*.leafs            =*/ NULL,
        /*.use_counts       =*/ cgraph0->use_counts,
        /*.visited_hash_set =*/ cgraph0->visited_hash_set,
        /*.order            =*/ cgraph0->order,
    };

    return cgraph;
}

void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
    GGML_ASSERT(dst->size >= src->n_leafs);
    GGML_ASSERT(dst->size >= src->n_nodes);
    GGML_ASSERT(dst->visited_hash_set.size >= src->visited_hash_set.size);

    dst->n_leafs = src->n_leafs;
    dst->n_nodes = src->n_nodes;
    dst->order   = src->order;

    for (int i = 0; i < src->n_leafs; ++i) {
        dst->leafs[i] = src->leafs[i];
    }

    for (int i = 0; i < src->n_nodes; ++i) {
        dst->nodes[i] = src->nodes[i];
    }

    for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
        // copy all hashset keys (tensors) that are in use
        if (ggml_bitset_get(src->visited_hash_set.used, i)) {
            size_t new_hash_pos = ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
            dst->use_counts[new_hash_pos] = src->use_counts[i];
        }
    }

    if (dst->grads) {
        memset(dst->grads,     0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
        memset(dst->grad_accs, 0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
    }
    if (src->grads) {
        GGML_ASSERT(dst->grads     != NULL);
        GGML_ASSERT(dst->grad_accs != NULL);
        for (int i = 0; i < src->n_nodes; ++i) {
            const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
            const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);

            GGML_ASSERT(igrad_src != GGML_HASHSET_FULL);
            GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src));
            GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL);
            GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));

            dst->grads[igrad_dst]     = src->grads[igrad_src];
            dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
        }
    }
}

struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
    ggml_graph_cpy(cgraph, result);
    return result;
}

struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
    if (ggml_is_empty(tensor)) {
        return tensor;
    }
    if (tensor->buffer) {
        ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));
    } else {
        GGML_ASSERT(tensor->data);
        memset(tensor->data, 0, ggml_nbytes(tensor));
    }
    return tensor;
}

void ggml_graph_reset(struct ggml_cgraph * cgraph) {
    if (!cgraph) {
        return;
    }
    GGML_ASSERT(cgraph->grads != NULL);

    for (int i = 0; i < cgraph->n_nodes; i++) {
        struct ggml_tensor * node     = cgraph->nodes[i];
        struct ggml_tensor * grad_acc = ggml_graph_get_grad_acc(cgraph, node);

        if (node->op == GGML_OP_OPT_STEP_ADAMW) {
            // clear momenta
            ggml_set_zero(node->src[2]);
            ggml_set_zero(node->src[3]);
        }

        // initial gradients of loss should be 1, 0 otherwise
        if (grad_acc) {
            if (node->flags & GGML_TENSOR_FLAG_LOSS) {
                GGML_ASSERT(grad_acc->type == GGML_TYPE_F32);
                GGML_ASSERT(ggml_is_scalar(grad_acc));

                const float onef = 1.0f;
                if (grad_acc->buffer) {
                    ggml_backend_tensor_set(grad_acc, &onef, 0, sizeof(float));
                } else {
                    GGML_ASSERT(grad_acc->data);
                    *((float *) grad_acc->data) = onef;
                }
            } else {
                ggml_set_zero(grad_acc);
            }
        }
    }
}
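
// Illustrative note (not part of the original ggml sources): in a training loop the accumulated
// gradients are typically reinitialized before each evaluation of the backward graph, e.g.:
//
//     ggml_graph_reset(gb); // loss grad_acc <- 1.0f, all other grad_accs <- 0, ADAMW momenta <- 0
//     // ... evaluate gb with a backend, then apply the optimizer step ...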

void ggml_graph_clear(struct ggml_cgraph * cgraph) {
    cgraph->n_leafs = 0;
    cgraph->n_nodes = 0;
    ggml_hash_set_reset(&cgraph->visited_hash_set);
}

int ggml_graph_size(struct ggml_cgraph * cgraph) {
    return cgraph->size;
}

struct ggml_tensor * ggml_graph_node(struct ggml_cgraph * cgraph, int i) {
    if (i < 0) {
        GGML_ASSERT(cgraph->n_nodes + i >= 0);
        return cgraph->nodes[cgraph->n_nodes + i];
    }

    GGML_ASSERT(i < cgraph->n_nodes);
    return cgraph->nodes[i];
}

struct ggml_tensor ** ggml_graph_nodes(struct ggml_cgraph * cgraph) {
    return cgraph->nodes;
}

int ggml_graph_n_nodes(struct ggml_cgraph * cgraph) {
    return cgraph->n_nodes;
}

void ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
    GGML_ASSERT(cgraph->size > cgraph->n_nodes);
    cgraph->nodes[cgraph->n_nodes] = tensor;
    cgraph->n_nodes++;
}

struct ggml_tensor * ggml_graph_get_tensor(const struct ggml_cgraph * cgraph, const char * name) {
    for (int i = 0; i < cgraph->n_leafs; i++) {
        struct ggml_tensor * leaf = cgraph->leafs[i];

        if (strcmp(leaf->name, name) == 0) {
            return leaf;
        }
    }

    for (int i = 0; i < cgraph->n_nodes; i++) {
        struct ggml_tensor * node = cgraph->nodes[i];

        if (strcmp(node->name, name) == 0) {
            return node;
        }
    }

    return NULL;
}

struct ggml_tensor * ggml_graph_get_grad(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grads ? cgraph->grads[igrad] : NULL;
}

struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grad_accs ? cgraph->grad_accs[igrad] : NULL;
}
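
// Illustrative note (not part of the original ggml sources; "embd" is a hypothetical tensor name):
// negative indices address nodes from the end of the graph, which is convenient for fetching the
// tensor the graph was expanded from, e.g.:
//
//     struct ggml_tensor * result = ggml_graph_node(gf, -1);           // last node added to gf
//     struct ggml_tensor * embd   = ggml_graph_get_tensor(gf, "embd"); // lookup by name, NULL if absent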

void ggml_graph_print(const struct ggml_cgraph * cgraph) {
    GGML_LOG_INFO("=== GRAPH ===\n");

    GGML_LOG_INFO("n_nodes = %d\n", cgraph->n_nodes);
    for (int i = 0; i < cgraph->n_nodes; i++) {
        struct ggml_tensor * node = cgraph->nodes[i];

        GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n",
                i,
                node->ne[0], node->ne[1], node->ne[2],
                ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" :
                    ggml_graph_get_grad(cgraph, node) ? "g" : " ");
    }

    GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs);
    for (int i = 0; i < cgraph->n_leafs; i++) {
        struct ggml_tensor * node = cgraph->leafs[i];

        GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
                i,
                node->ne[0], node->ne[1],
                ggml_op_name(node->op),
                ggml_get_name(node));
    }

    GGML_LOG_INFO("========================================\n");
}

static int ggml_node_list_find_tensor(const struct ggml_cgraph * cgraph,
                                      const int *                idxs,
                                      int                        count,
                                      const struct ggml_tensor * tensor) {
    GGML_ASSERT(cgraph && idxs);
    for (int i = 0; i < count; ++i) {
        const int node_idx = idxs[i];
        if (node_idx >= cgraph->n_nodes) {
            return -1;
        }
        if (cgraph->nodes[node_idx] == tensor) {
            return i;
        }
    }
    return -1;
}

bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
                                const int *                node_idxs,
                                int                        count,
                                const enum ggml_op *       ops,
                                const int *                outputs,
                                int                        num_outputs) {
    GGML_ASSERT(outputs && num_outputs > 0);

    for (int i = 0; i < count; ++i) {
        if (node_idxs[i] >= cgraph->n_nodes) {
            return false;
        }

        const struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];
        if (node->op != ops[i]) {
            return false;
        }

        if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
            continue;
        }

        if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
            return false;
        }

        int subgraph_uses = 0;
        for (int j = i + 1; j < count; ++j) {
            const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];
            for (int src_idx = 0; src_idx < GGML_MAX_SRC; src_idx++) {
                if (other_node->src[src_idx] == node) {
                    subgraph_uses++;
                }
            }
        }

        if (subgraph_uses != ggml_node_get_use_count(cgraph, node_idxs[i])) {
            return false;
        }

        // if node is a view, check if the view_src and all of its parent view_srcs are within the subgraph
        struct ggml_tensor * view_src = node->view_src;
        while (view_src) {
            if (ggml_node_list_find_tensor(cgraph, node_idxs, count, view_src) == -1) {
                return false;
            }
            view_src = view_src->view_src;
        }
    }

    return true;
}
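
// Illustrative note (not part of the original ggml sources; the index i and the op list are
// hypothetical): a backend could ask whether two consecutive nodes form a MUL_MAT + ADD chain whose
// only value that survives outside the subgraph is the ADD result, e.g.:
//
//     const int          idxs[2] = { i, i + 1 };
//     const enum ggml_op ops [2] = { GGML_OP_MUL_MAT, GGML_OP_ADD };
//     const int          outs[1] = { i + 1 };
//     if (ggml_can_fuse_subgraph_ext(cgraph, idxs, 2, ops, outs, 1)) {
//         // safe to replace both nodes with a single fused kernel
//     }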

// check if node is part of the graph
static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
    if (cgraph == NULL) {
        return true;
    }

    for (int i = 0; i < cgraph->n_nodes; i++) {
        if (cgraph->nodes[i] == node) {
            return true;
        }
    }

    return false;
}

static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
    for (int i = 0; i < cgraph->n_nodes; i++) {
        struct ggml_tensor * parent = cgraph->nodes[i];
        struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, parent);

        if (grad == node) {
            return parent;
        }
    }

    return NULL;
}

static void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) {
    struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node);
    struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent);

    fprintf(fp, " \"%p\" -> \"%p\" [ arrowhead = %s; style = %s; label = \"%s\"; ]\n",
            gparent0 ? (void *) gparent0 : (void *) parent,
            gparent ? (void *) gparent : (void *) node,
            gparent ? "empty" : "vee",
            gparent ? "dashed" : "solid",
            label);
}

static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) {
    fprintf(fp, " \"%p\" -> \"%p\" [ label = \"%s\"; ]\n",
            (void *) parent,
            (void *) node,
            label);
}

void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
    char color[16];

    FILE * fp = ggml_fopen(filename, "w");
    GGML_ASSERT(fp);

    fprintf(fp, "digraph G {\n");
    fprintf(fp, " newrank = true;\n");
    fprintf(fp, " rankdir = TB;\n");

    for (int i = 0; i < gb->n_nodes; i++) {
        struct ggml_tensor * node = gb->nodes[i];
        struct ggml_tensor * grad = ggml_graph_get_grad(gb, node);

        if (ggml_graph_get_parent(gb, node) != NULL) {
            continue;
        }

        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
            snprintf(color, sizeof(color), "yellow");
        } else if (grad) {
            if (ggml_graph_find(gf, node)) {
                snprintf(color, sizeof(color), "green");
            } else {
                snprintf(color, sizeof(color), "lightblue");
            }
        } else {
            snprintf(color, sizeof(color), "white");
        }

        fprintf(fp, " \"%p\" [ "
                    "style = filled; fillcolor = %s; shape = record; "
                    "label=\"",
                (void *) node, color);

        if (strlen(node->name) > 0) {
            fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type));
        } else {
            fprintf(fp, "(%s)|", ggml_type_name(node->type));
        }

        if (ggml_is_matrix(node)) {
            fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], ggml_op_symbol(node->op));
        } else {
            fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op));
        }

        if (grad) {
            fprintf(fp, " | <g>%s\"; ]\n", ggml_op_symbol(grad->op));
        } else {
            fprintf(fp, "\"; ]\n");
        }
    }

    for (int i = 0; i < gb->n_leafs; i++) {
        struct ggml_tensor * node = gb->leafs[i];

        snprintf(color, sizeof(color), "pink");

        fprintf(fp, " \"%p\" [ "
                    "style = filled; fillcolor = %s; shape = record; "
                    "label=\"<x>",
                (void *) node, color);

        if (strlen(node->name) > 0) {
            fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type));
        } else {
            fprintf(fp, "(%s)|", ggml_type_name(node->type));
        }

        fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
        if (ggml_nelements(node) < 5 && node->data != NULL) {
            fprintf(fp, " | (");
            for (int j = 0; j < ggml_nelements(node); j++) {
                // FIXME: use ggml-backend to obtain the tensor data
                //if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
                //    fprintf(fp, "%d", ggml_get_i32_1d(node, j));
                //}
                //else if (node->type == GGML_TYPE_F32 ||
                //         node->type == GGML_TYPE_F16 ||
                //         node->type == GGML_TYPE_BF16) {
                //    fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j));
                //}
                //else
                {
                    fprintf(fp, "#");
                }
                if (j < ggml_nelements(node) - 1) {
                    fprintf(fp, ", ");
                }
            }
            fprintf(fp, ")");
        }
        fprintf(fp, "\"; ]\n");
    }

    for (int i = 0; i < gb->n_nodes; i++) {
        struct ggml_tensor * node = gb->nodes[i];

        for (int j = 0; j < GGML_MAX_SRC; j++) {
            if (node->src[j]) {
                char label[16];
                snprintf(label, sizeof(label), "src %d", j);
                ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label);
            }
        }
    }

    for (int i = 0; i < gb->n_leafs; i++) {
        struct ggml_tensor * node = gb->leafs[i];

        for (int j = 0; j < GGML_MAX_SRC; j++) {
            if (node->src[j]) {
                char label[16];
                snprintf(label, sizeof(label), "src %d", j);
                ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label);
            }
        }
    }

    fprintf(fp, "}\n");

    fclose(fp);

    GGML_LOG_INFO("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename);
}
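
// Illustrative note (not part of the original ggml sources): a graph can be visualized by dumping
// it and rendering the result with graphviz; passing NULL for the forward graph is accepted and
// only affects node coloring, e.g.:
//
//     ggml_graph_dump_dot(gf, NULL, "ggml-graph.dot");
//     // then: dot -Tpng ggml-graph.dot -o ggml-graph.png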

////////////////////////////////////////////////////////////////////////////////

void ggml_set_input(struct ggml_tensor * tensor) {
    tensor->flags |= GGML_TENSOR_FLAG_INPUT;
}

void ggml_set_output(struct ggml_tensor * tensor) {
    tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
}

void ggml_set_param(struct ggml_tensor * tensor) {
    GGML_ASSERT(tensor->op == GGML_OP_NONE);
    tensor->flags |= GGML_TENSOR_FLAG_PARAM;
}

void ggml_set_loss(struct ggml_tensor * tensor) {
    GGML_ASSERT(ggml_is_scalar(tensor));
    GGML_ASSERT(tensor->type == GGML_TYPE_F32);
    tensor->flags |= GGML_TENSOR_FLAG_LOSS;
}

////////////////////////////////////////////////////////////////////////////////

void ggml_quantize_init(enum ggml_type type) {
    ggml_critical_section_start();

    switch (type) {
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:   iq2xs_init_impl(type); break;
        case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
        case GGML_TYPE_IQ3_S:   iq3xs_init_impl(512); break;
        default: // nothing
            break;
    }

    ggml_critical_section_end();
}

void ggml_quantize_free(void) {
    ggml_critical_section_start();

    iq2xs_free_impl(GGML_TYPE_IQ2_XXS);
    iq2xs_free_impl(GGML_TYPE_IQ2_XS);
    iq2xs_free_impl(GGML_TYPE_IQ1_S);
    iq3xs_free_impl(256);

    ggml_critical_section_end();
}

bool ggml_quantize_requires_imatrix(enum ggml_type type) {
    return
        type == GGML_TYPE_IQ2_XXS ||
        type == GGML_TYPE_IQ2_XS ||
        type == GGML_TYPE_IQ1_S;// ||
        //type == GGML_TYPE_IQ1_M;
}

size_t ggml_quantize_chunk(
        enum ggml_type   type,
        const float    * src,
        void           * dst,
        int64_t          start,
        int64_t          nrows,
        int64_t          n_per_row,
        const float    * imatrix) {
    const int64_t n = (int64_t) nrows * n_per_row;

    if (ggml_quantize_requires_imatrix(type)) {
        GGML_ASSERT(imatrix != NULL);
    }

    GGML_ASSERT(start % type_traits[type].blck_size == 0);
    GGML_ASSERT(start % n_per_row == 0);

    ggml_quantize_init(type); // this is noop if already initialized

    const size_t start_row = start / n_per_row;
    const size_t row_size  = ggml_row_size(type, n_per_row);

    size_t result = 0;

    switch (type) {
        case GGML_TYPE_Q4_0:    result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q4_1:    result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q5_0:    result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q5_1:    result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q8_0:    result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_MXFP4:   result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q2_K:    result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q3_K:    result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q4_K:    result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q5_K:    result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_Q6_K:    result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_TQ1_0:   result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_TQ2_0:   result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ2_XS:  result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ3_S:   result = quantize_iq3_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ2_S:   result = quantize_iq2_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ1_S:   result = quantize_iq1_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ1_M:   result = quantize_iq1_m  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ4_NL:  result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ4_XS:  result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_F16:
            {
                size_t elemsize = sizeof(ggml_fp16_t);
                ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
                result = n * elemsize;
            } break;
        case GGML_TYPE_BF16:
            {
                size_t elemsize = sizeof(ggml_bf16_t);
                ggml_fp32_to_bf16_row_ref(src + start, (ggml_bf16_t *)dst + start, n);
                result = n * elemsize;
            } break;
        case GGML_TYPE_F32:
            {
                size_t elemsize = sizeof(float);
                result = n * elemsize;
                memcpy((uint8_t *)dst + start * elemsize, src + start, result);
            } break;
        default:
            assert(false);
    }

    GGML_ASSERT(result == nrows * row_size);

    return result;
}
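
// Illustrative usage sketch (not part of the original ggml sources; buffer names and sizes are
// hypothetical). Quantizing a full F32 matrix of nrows x n_per_row values to Q8_0 could look like:
//
//     const int64_t nrows     = 32;
//     const int64_t n_per_row = 256; // must be a multiple of the block size
//     float * src = malloc(nrows*n_per_row*sizeof(float)); // filled with the original weights
//     void  * dst = malloc(nrows*ggml_row_size(GGML_TYPE_Q8_0, n_per_row));
//
//     const size_t written = ggml_quantize_chunk(GGML_TYPE_Q8_0, src, dst,
//                                                /*start =*/ 0, nrows, n_per_row, /*imatrix =*/ NULL);
//     // written == nrows * ggml_row_size(GGML_TYPE_Q8_0, n_per_row)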

////////////////////////////////////////////////////////////////////////////////

void ggml_log_get(ggml_log_callback * log_callback, void ** user_data) {
    *log_callback = g_logger_state.log_callback;
    *user_data    = g_logger_state.log_callback_user_data;
}

void ggml_log_set(ggml_log_callback log_callback, void * user_data) {
    g_logger_state.log_callback           = log_callback ? log_callback : ggml_log_callback_default;
    g_logger_state.log_callback_user_data = user_data;
}

void ggml_threadpool_params_init(struct ggml_threadpool_params * p, int n_threads) {
    p->n_threads  = n_threads;
    p->prio       = 0;     // default priority (usually means normal or inherited)
    p->poll       = 50;    // hybrid-polling enabled
    p->strict_cpu = false; // no strict placement (all threads share same cpumask)
    p->paused     = false; // threads are ready to go
    memset(p->cpumask, 0, GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited)
}

struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) {
    struct ggml_threadpool_params p;
    ggml_threadpool_params_init(&p, n_threads);
    return p;
}

bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) {
    if (p0->n_threads  != p1->n_threads ) return false;
    if (p0->prio       != p1->prio      ) return false;
    if (p0->poll       != p1->poll      ) return false;
    if (p0->strict_cpu != p1->strict_cpu) return false;
    return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;