| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265 |
- #include "mmq.cuh"
- #include "vecdotq.cuh"
// Function-pointer types for the per-quantization-type MMQ helpers
// (allocate_tiles_*, load_tiles_*, vec_dot_*_q8_1_mul_mat).

// Assigns tile buffers to the output pointers; a given type may leave some of them untouched.
typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);

// Copies a slice of the quantized matrix vx into the x tile buffers.
typedef void (*load_tiles_cuda_t)(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row);

// Dot product between x tile row i and y tile column j at int offset k.
// The y scale parameter is named y_ds to match the vec_dot_*_q8_1_mul_mat implementations.
typedef float (*vec_dot_q_mul_mat_cuda_t)(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k);

// Callback type with no implementation in this section; presumably accumulates a result into v — confirm at call sites.
typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
// Allocates the q4_0 shared-memory tiles: packed quant ints plus one float scale per block.
// The float scale tile is handed out through the half2 pointer x_dm; x_qh and x_sc are not used.
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_sc);
    GGML_UNUSED(x_qh);

    // The extra mmq_y ints presumably back the (WARP_SIZE + 1) row stride used by load_tiles_q4_0.
    __shared__ int   tile_qs[mmq_y * (WARP_SIZE) + mmq_y];
    __shared__ float tile_d [mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0];

    *x_dm = reinterpret_cast<half2 *>(tile_d);
    *x_ql = tile_qs;
}
// Loads an mmq_y-row tile of q4_0 data from vx into the tile buffers.
//   x_ql : receives the packed quant ints, row stride WARP_SIZE + 1 ints.
//   x_dm : reinterpreted as float, receives the scale d of every q4_0 block.
// Thread (i_offset, k) handles int k of rows i_offset, i_offset + nwarps, ...
// When need_check is true, row indices are clamped to i_max before reading and writing.
// x_qh and x_sc are unused for this type.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset <  nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k <  WARP_SIZE);

    const int kbx  = k / QI4_0; // block (within the row) this thread reads quants from
    const int kqsx = k % QI4_0; // int within that block

    const block_q4_0 * bx0 = (const block_q4_0 *) vx;

    // Pass 1: quants.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;

        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
    }

    // Pass 2: per-block scales. Each tile row holds WARP_SIZE/QI4_0 scales, so a warp fills
    // several rows per iteration: k / blocks_per_tile_x_row selects the row, kbxd the block in it.
    // The i / QI4_0 term presumably pads the layout; it must match vec_dot_q4_0_q8_1_mul_mat.
    float * x_dmf = (float *) x_dm;

    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
    const int kbxd = k % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
        int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
    }
}
// Dot product of row i of the q4_0 x tile with column j of the q8_1 y tile, starting at x int k.
// x_dm holds float scales (see allocate_tiles_q4_0); x_qh and x_sc are unused.
static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_sc);
    GGML_UNUSED(x_qh);

    const float * x_scales = (const float *) x_dm;

    // Map x int index k to the corresponding y int index within column j.
    const int   y_base = QI8_1 * (k / (QI8_1/2)) + k % (QI8_1/2);
    const int * y_row  = y_qs + j * WARP_SIZE;

    // Gather pairs of y ints; the second of each pair lies QI4_0 ints after the first (mod WARP_SIZE).
    int u[2*VDR_Q4_0_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
        const int first  = (y_base + l)         % WARP_SIZE;
        const int second = (y_base + l + QI4_0) % WARP_SIZE;
        u[2*l + 0] = y_row[first];
        u[2*l + 1] = y_row[second];
    }

    const float d_x  = x_scales[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0];
    const half2 ds_y = y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)];

    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>(&x_ql[i * (WARP_SIZE + 1) + k], u, d_x, ds_y);
}
// Allocates the q4_1 shared-memory tiles: packed quant ints plus one half2 per block
// (filled from block_q4_1::dm by load_tiles_q4_1). x_qh and x_sc are unused for this type.
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

    // The extra mmq_y ints presumably back the (WARP_SIZE + 1) row stride used by the loader.
    // (Previously written as "+ + mmq_y": a harmless but confusing stray unary plus.)
    __shared__ int   tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y];
    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1];

    *x_ql = tile_x_qs;
    *x_dm = tile_x_dm;
}
// Loads an mmq_y-row tile of q4_1 data from vx into the tile buffers.
//   x_ql receives the packed quant ints (row stride WARP_SIZE + 1 ints);
//   x_dm receives each block's half2 dm field.
// Thread (i_offset, k) handles int k of rows i_offset + n*nwarps; with need_check rows are clamped to i_max.
// x_qh and x_sc are unused for this type.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_sc);
    GGML_UNUSED(x_qh);

    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);

    const block_q4_1 * blocks = (const block_q4_1 *) vx;

    // Quants.
    {
        const int blk = k / QI4_1;
        const int iqs = k % QI4_1;
#pragma unroll
        for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
            const int row = need_check ? min(row0 + i_offset, i_max) : row0 + i_offset;
            x_ql[row * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(blocks[row*blocks_per_row + blk].qs, iqs);
        }
    }

    // Per-block dm values: WARP_SIZE/QI4_1 per tile row, several rows per warp iteration.
    {
        const int dm_per_row = WARP_SIZE / QI4_1;
        const int blk_dm     = k % dm_per_row;
#pragma unroll
        for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI4_1) {
            int row = row0 + i_offset * QI4_1 + k / dm_per_row;
            if (need_check) {
                row = min(row, i_max);
            }
            x_dm[row * (WARP_SIZE/QI4_1) + row / QI4_1 + blk_dm] = blocks[row*blocks_per_row + blk_dm].dm;
        }
    }
}
// Dot product of row i of the q4_1 x tile with column j of the q8_1 y tile, starting at x int k.
// x_qh and x_sc are unused for this type.
static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_sc);
    GGML_UNUSED(x_qh);

    const int * y_row  = y_qs + j * WARP_SIZE;
    const int   y_base = QI8_1 * (k / (QI8_1/2)) + k % (QI8_1/2);

    // Gather pairs of y ints; the second of each pair lies QI4_1 ints after the first (mod WARP_SIZE).
    int u[2*VDR_Q4_1_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
        u[2*l + 0] = y_row[(y_base + l)         % WARP_SIZE];
        u[2*l + 1] = y_row[(y_base + l + QI4_1) % WARP_SIZE];
    }

    const half2 dm_x = x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1];
    const half2 ds_y = y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)];

    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>(&x_ql[i * (WARP_SIZE + 1) + k], u, dm_x, ds_y);
}
// Allocates the q5_0 shared-memory tiles: 2*WARP_SIZE (+1 padding) quant ints per row, because
// load_tiles_q5_0 writes two ints per source int, plus one float scale per block handed out
// through the half2 pointer x_dm. x_qh and x_sc are not used.
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_sc);
    GGML_UNUSED(x_qh);

    __shared__ int   tile_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
    __shared__ float tile_d [mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0];

    *x_dm = reinterpret_cast<half2 *>(tile_d);
    *x_ql = tile_ql;
}
// Loads an mmq_y-row tile of q5_0 data from vx into the tile buffers.
// Each source quant int is expanded into two ints in x_ql (row stride 2*WARP_SIZE + 1):
// the 4-bit values plus their fifth bit taken from qh, with 16 subtracted from every byte.
// x_dm is reinterpreted as float and receives the per-block scale d.
// With need_check, row indices are clamped to i_max. x_qh and x_sc are unused.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_sc);
    GGML_UNUSED(x_qh);

    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);

    const block_q5_0 * blocks = (const block_q5_0 *) vx;

    const int blk = k / QI5_0;
    const int iqs = k % QI5_0;

#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
        int row = row0 + i_offset;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_q5_0 * b = blocks + row*blocks_per_row + blk;

        const int nibbles   = get_int_from_uint8(b->qs, iqs);
        const int high_bits = get_int_from_uint8(b->qh, 0) >> (4 * iqs);

        int lo = nibbles & 0x0F0F0F0F;
        lo |= (high_bits <<  4) & 0x00000010; // bit  0 -> bit  4
        lo |= (high_bits << 11) & 0x00001000; // bit  1 -> bit 12
        lo |= (high_bits << 18) & 0x00100000; // bit  2 -> bit 20
        lo |= (high_bits << 25) & 0x10000000; // bit  3 -> bit 28
        lo  = __vsubss4(lo, 0x10101010);      // subtract 16 from each byte

        int hi = (nibbles >> 4) & 0x0F0F0F0F;
        hi |= (high_bits >> 12) & 0x00000010; // bit 16 -> bit  4
        hi |= (high_bits >>  5) & 0x00001000; // bit 17 -> bit 12
        hi |= (high_bits <<  2) & 0x00100000; // bit 18 -> bit 20
        hi |= (high_bits <<  9) & 0x10000000; // bit 19 -> bit 28
        hi  = __vsubss4(hi, 0x10101010);      // subtract 16 from each byte

        int * dst = x_ql + row * (2*WARP_SIZE + 1) + 2*k;
        dst[0] = lo;
        dst[1] = hi;
    }

    // Per-block float scales.
    float * x_d = (float *) x_dm;
    const int d_per_row = WARP_SIZE / QI5_0;
    const int blk_d     = k % d_per_row;

#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI5_0) {
        int row = row0 + i_offset * QI5_0 + k / d_per_row;
        if (need_check) {
            row = min(row, i_max);
        }
        x_d[row * (WARP_SIZE/QI5_0) + row / QI5_0 + blk_d] = blocks[row*blocks_per_row + blk_d].d;
    }
}
// Dot product of row i of the q5_0 x tile with column j of the q8_1 y tile, starting at x int k.
// load_tiles_q5_0 stores x as bytes with 16 already subtracted, so the q8_0 x q8_1 kernel is reused.
// Both x_dm and y_ds are read as float arrays here. x_qh and x_sc are unused.
static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_sc);
    GGML_UNUSED(x_qh);

    const float * x_scales = (const float *) x_dm;
    const float * y_scales = (const float *) y_ds;

    const int * y_row  = y_qs + j * WARP_SIZE;
    const int   y_base = QI8_1 * (k / (QI8_1/2)) + k % (QI8_1/2);

    int u[2*VDR_Q5_0_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
        u[2*l + 0] = y_row[(y_base + l)         % WARP_SIZE];
        u[2*l + 1] = y_row[(y_base + l + QI5_0) % WARP_SIZE];
    }

    const float d_x = x_scales[i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0];
    const float d_y = y_scales[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)];

    return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>(&x_ql[i * (2*WARP_SIZE + 1) + 2*k], u, d_x, d_y);
}
// Allocates the q5_1 shared-memory tiles: 2*WARP_SIZE (+1 padding) quant ints per row plus one
// half2 per block (filled from block_q5_1::dm by load_tiles_q5_1). x_qh and x_sc are not used.
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_sc);
    GGML_UNUSED(x_qh);

    __shared__ int   tile_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
    __shared__ half2 tile_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1];

    *x_dm = tile_dm;
    *x_ql = tile_ql;
}
// Loads an mmq_y-row tile of q5_1 data from vx into the tile buffers.
// Each source quant int is expanded into two ints in x_ql (row stride 2*WARP_SIZE + 1): the 4-bit
// values with their fifth bit merged in from qh (no offset is subtracted, unlike q5_0).
// x_dm receives each block's half2 dm field. With need_check, rows are clamped to i_max.
// x_qh and x_sc are unused.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_sc);
    GGML_UNUSED(x_qh);

    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);

    const block_q5_1 * blocks = (const block_q5_1 *) vx;

    const int blk = k / QI5_1;
    const int iqs = k % QI5_1;

#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
        int row = row0 + i_offset;
        if (need_check) {
            row = min(row, i_max);
        }

        const block_q5_1 * b = blocks + row*blocks_per_row + blk;

        const int nibbles   = get_int_from_uint8_aligned(b->qs, iqs);
        const int high_bits = get_int_from_uint8_aligned(b->qh, 0) >> (4 * iqs);

        int lo = nibbles & 0x0F0F0F0F;
        lo |= (high_bits <<  4) & 0x00000010; // bit  0 -> bit  4
        lo |= (high_bits << 11) & 0x00001000; // bit  1 -> bit 12
        lo |= (high_bits << 18) & 0x00100000; // bit  2 -> bit 20
        lo |= (high_bits << 25) & 0x10000000; // bit  3 -> bit 28

        int hi = (nibbles >> 4) & 0x0F0F0F0F;
        hi |= (high_bits >> 12) & 0x00000010; // bit 16 -> bit  4
        hi |= (high_bits >>  5) & 0x00001000; // bit 17 -> bit 12
        hi |= (high_bits <<  2) & 0x00100000; // bit 18 -> bit 20
        hi |= (high_bits <<  9) & 0x10000000; // bit 19 -> bit 28

        int * dst = x_ql + row * (2*WARP_SIZE + 1) + 2*k;
        dst[0] = lo;
        dst[1] = hi;
    }

    // Per-block dm values.
    const int dm_per_row = WARP_SIZE / QI5_1;
    const int blk_dm     = k % dm_per_row;

#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI5_1) {
        int row = row0 + i_offset * QI5_1 + k / dm_per_row;
        if (need_check) {
            row = min(row, i_max);
        }
        x_dm[row * (WARP_SIZE/QI5_1) + row / QI5_1 + blk_dm] = blocks[row*blocks_per_row + blk_dm].dm;
    }
}
// Dot product of row i of the q5_1 x tile with column j of the q8_1 y tile, starting at x int k.
// load_tiles_q5_1 has already merged the fifth bit into each quant, so the q8_1 x q8_1 kernel is reused.
// x_qh and x_sc are unused for this type.
static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

    const int kyqs     = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
    // Index of this block's half2 dm entry, matching the layout written by load_tiles_q5_1.
    // (Previously contained a stray "+ +" unary plus.)
    const int index_bx = i * (WARP_SIZE/QI5_1) + i/QI5_1 + k/QI5_1;

    // Gather pairs of y ints; the second of each pair lies QI5_1 ints after the first (mod WARP_SIZE).
    int u[2*VDR_Q5_1_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
    }

    return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
// Allocates the q8_0 shared-memory tiles: quant ints (one padding int per row) plus one float
// scale per block, handed out through the half2 pointer x_dm. x_qh and x_sc are not used.
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_sc);
    GGML_UNUSED(x_qh);

    __shared__ int   tile_qs[mmq_y * (WARP_SIZE) + mmq_y];
    __shared__ float tile_d [mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0];

    *x_dm = reinterpret_cast<half2 *>(tile_d);
    *x_ql = tile_qs;
}
// Loads an mmq_y-row tile of q8_0 data from vx into the tile buffers.
//   x_ql receives the int8 quants as ints (row stride WARP_SIZE + 1);
//   x_dm, reinterpreted as float, receives each block's scale d.
// With need_check, rows are clamped to i_max. x_qh and x_sc are unused.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_sc);
    GGML_UNUSED(x_qh);

    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);

    const block_q8_0 * blocks = (const block_q8_0 *) vx;

    const int blk = k / QI8_0;
    const int iqs = k % QI8_0;
#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
        const int row = need_check ? min(row0 + i_offset, i_max) : row0 + i_offset;
        x_ql[row * (WARP_SIZE + 1) + k] = get_int_from_int8(blocks[row*blocks_per_row + blk].qs, iqs);
    }

    float * x_d = (float *) x_dm;
    const int d_per_row = WARP_SIZE / QI8_0;
    const int blk_d     = k % d_per_row;
#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI8_0) {
        int row = row0 + i_offset * QI8_0 + k / d_per_row;
        if (need_check) {
            row = min(row, i_max);
        }
        x_d[row * (WARP_SIZE/QI8_0) + row / QI8_0 + blk_d] = blocks[row*blocks_per_row + blk_d].d;
    }
}
// Dot product of row i of the q8_0 x tile with column j of the q8_1 y tile, starting at int k.
// Both x_dm and y_ds are read as float arrays. x_qh and x_sc are unused.
static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_sc);
    GGML_UNUSED(x_qh);

    const float d_x = ((const float *) x_dm)[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0];
    const float d_y = ((const float *) y_ds)[j * (WARP_SIZE/QI8_1) + k/QI8_1];

    const int * x_ints = x_ql + i * (WARP_SIZE + 1) + k;
    const int * y_ints = y_qs + j * WARP_SIZE + k;

    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>(x_ints, y_ints, d_x, d_y);
}
// Allocates the q2_K shared-memory tiles: quant ints, one half2 per block, and WARP_SIZE/4 scale
// ints per row (each with padding). x_qh is not used for this type.
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_qh);

    __shared__ int   tile_ql[mmq_y * (WARP_SIZE) + mmq_y];
    __shared__ half2 tile_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K];
    __shared__ int   tile_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4];

    *x_sc = tile_sc;
    *x_dm = tile_dm;
    *x_ql = tile_ql;
}
// Loads an mmq_y-row tile of q2_K data from vx into the tile buffers:
//   x_ql : packed quant ints, row stride WARP_SIZE + 1;
//   x_dm : each block's half2 dm field;
//   x_sc : raw scale bytes as ints, WARP_SIZE/4 per row.
// With need_check, rows are clamped to i_max. x_qh is unused.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_qh);

    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);

    const block_q2_K * blocks = (const block_q2_K *) vx;

    // Quants: thread (i_offset, k) copies int k of rows i_offset, i_offset + nwarps, ...
    const int blk = k / QI2_K;
    const int iqs = k % QI2_K;
#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
        const int row = need_check ? min(row0 + i_offset, i_max) : row0 + i_offset;
        x_ql[row * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(blocks[row*blocks_per_row + blk].qs, iqs);
    }

    // dm: the row index wraps modulo mmq_y so every write stays inside the tile.
    const int dm_per_row = WARP_SIZE / QI2_K;
    const int blk_dm     = k % dm_per_row;
#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI2_K) {
        int row = (row0 + i_offset * QI2_K + k / dm_per_row) % mmq_y;
        if (need_check) {
            row = min(row, i_max);
        }
        x_dm[row * (WARP_SIZE/QI2_K) + row / QI2_K + blk_dm] = blocks[row*blocks_per_row + blk_dm].dm;
    }

    // Scales: WARP_SIZE/4 ints per row, four rows per warp iteration.
    const int isc = k % (WARP_SIZE/4);
#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += nwarps * 4) {
        int row = row0 + i_offset * 4 + k / (WARP_SIZE/4);
        if (need_check) {
            row = min(row, i_max);
        }
        const block_q2_K * b = blocks + row*blocks_per_row + isc / (QI2_K/4);
        x_sc[row * (WARP_SIZE/4) + row / 4 + isc] = get_int_from_uint8_aligned(b->scales, k % (QI2_K/4));
    }
}
// Dot product of row i of the q2_K x tile with column j of the q8_1 y tile, starting at x int k.
// Extracts 2-bit fields from x_ql (mask 0x03030303) and passes the scale bytes from x_sc.
// x_qh is unused.
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_qh);

    const int kbx = k / QI2_K;
    const int ky  = QR2_K * (k % QI2_K);

    // First x_ql int and bit shift holding this thread's 2-bit fields.
    const int x_first = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
    const int shift   = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));

    int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
        v[l] = (x_ql[x_first + l] >> shift) & 0x03030303;
    }

    const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;

    const int   y_idx = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;
    const float d_y   = ((const float *) y_ds)[y_idx/QI8_1];
    const half2 dm_x  = x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx];

    return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[y_idx], scales, dm_x, d_y);
}
// Allocates all four q3_K shared-memory tiles (quants, scale d, high-bit mask, scales),
// each padded per row, and hands them out through the output pointers.
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    __shared__ int   tile_ql[mmq_y * (WARP_SIZE) + mmq_y];
    __shared__ half2 tile_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K];
    __shared__ int   tile_qh[mmq_y * (WARP_SIZE/2) + mmq_y/2];
    __shared__ int   tile_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4];

    *x_sc = tile_sc;
    *x_qh = tile_qh;
    *x_dm = tile_dm;
    *x_ql = tile_ql;
}
// Loads an mmq_y-row tile of q3_K data from vx into the four tile buffers.
// With need_check, row indices are clamped to i_max.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);

    const block_q3_K * blocks = (const block_q3_K *) vx;

    // 1) qs ints, row stride WARP_SIZE + 1 (the dot product later extracts 2-bit fields from them).
    {
        const int blk = k / QI3_K;
        const int iqs = k % QI3_K;
#pragma unroll
        for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
            const int row = need_check ? min(row0 + i_offset, i_max) : row0 + i_offset;
            x_ql[row * (WARP_SIZE + 1) + k] = get_int_from_uint8(blocks[row*blocks_per_row + blk].qs, iqs);
        }
    }

    // 2) Block scale d, stored as float through x_dm; the row wraps modulo mmq_y to stay in the tile.
    {
        float * x_d = (float *) x_dm;
        const int d_per_row = WARP_SIZE / QI3_K;
        const int blk_d     = k % d_per_row;
#pragma unroll
        for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI3_K) {
            int row = (row0 + i_offset * QI3_K + k / d_per_row) % mmq_y;
            if (need_check) {
                row = min(row, i_max);
            }
            x_d[row * (WARP_SIZE/QI3_K) + row / QI3_K + blk_d] = blocks[row*blocks_per_row + blk_d].d;
        }
    }

    // 3) hmask, WARP_SIZE/2 ints per row. Stored inverted (~) so that a 0 bit later results in
    //    4 being subtracted and a 1 bit in 0.
    {
        const int iqh = k % (WARP_SIZE/2);
#pragma unroll
        for (int row0 = 0; row0 < mmq_y; row0 += nwarps * 2) {
            int row = row0 + i_offset * 2 + k / (WARP_SIZE/2);
            if (need_check) {
                row = min(row, i_max);
            }
            const block_q3_K * b = blocks + row*blocks_per_row + iqh / (QI3_K/2);
            x_qh[row * (WARP_SIZE/2) + row / 2 + iqh] = ~get_int_from_uint8(b->hmask, k % (QI3_K/2));
        }
    }

    // 4) Scales: 4 low bits and 2 high bits combined into one byte each, minus 32 per byte,
    //    WARP_SIZE/4 ints per row.
    {
        const int isc = k % (WARP_SIZE/4);
#pragma unroll
        for (int row0 = 0; row0 < mmq_y; row0 += nwarps * 4) {
            int row = row0 + i_offset * 4 + k / (WARP_SIZE/4);
            if (need_check) {
                row = min(row, i_max);
            }
            const block_q3_K * b = blocks + row*blocks_per_row + isc / (QI3_K/4);

            const int ksc = k % (QI3_K/4);

            // low 4 bits of each scale
            const int lo = (get_int_from_uint8(b->scales, ksc % (QI3_K/8)) >> (4 * (ksc / (QI3_K/8)))) & 0x0F0F0F0F;
            // high 2 bits of each scale, moved to bits 4..5 of every byte
            const int hi = ((get_int_from_uint8(b->scales, QI3_K/8) >> (2 * ksc)) << 4) & 0x30303030;

            x_sc[row * (WARP_SIZE/4) + row / 4 + isc] = __vsubss4(lo | hi, 0x20202020);
        }
    }
}
// Dot product of row i of the q3_K x tile with column j of the q8_1 y tile, starting at x int k.
// Each 3-bit quant is rebuilt from a 2-bit field of x_ql minus 4 where the (inverted) hmask bit
// stored in x_qh is set; scales come from x_sc as signed bytes and d from x_dm read as float.
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {

    const int kbx = k / QI3_K;
    const int ky  = (k % QI3_K) * QR3_K;
    const float * x_dmf = (const float *) x_dm;
    const float * y_df  = (const float *) y_ds;

    const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;

    // None of these depend on the loop index l, so compute them once rather than per iteration.
    const int kqsx  = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
    // NOTE(review): literals 32 and 8 kept as-is; compare with 2*QI3_K and QI3_K/2 used in kqsx.
    const int shift = 2 * ((ky % 32) / 8);
    const int kqh   = i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2);

    int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];

#pragma unroll
    for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
        const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;

        const int vh  = x_qh[kqh + (ky+l)%8] >> ((ky+l) / 8);
        const int vlh = (vh << 2) & 0x04040404; // 4 in every byte whose stored (inverted) mask bit is set

        v[l] = __vsubss4(vll, vlh);
    }

    const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
    return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
}
// Allocates the q4_K shared-memory tiles: quant ints, one half2 per block, and WARP_SIZE/8
// unpacked-scale ints per row (each with padding). x_qh is not used for this type.
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_qh);

    __shared__ int   tile_ql[mmq_y * (WARP_SIZE) + mmq_y];
    __shared__ half2 tile_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K];
    __shared__ int   tile_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8];

    *x_sc = tile_sc;
    *x_dm = tile_dm;
    *x_ql = tile_ql;
}
// Loads an mmq_y-row tile of q4_K data from vx into the tile buffers:
//   x_ql : packed 4-bit quant ints, row stride WARP_SIZE + 1;
//   x_dm : each block's dm (half2; assembled from dm[0], dm[1] when QK_K != 256);
//   x_sc : scale/min bytes unpacked from the block's packed scales, WARP_SIZE/8 ints per row.
// With need_check, rows are clamped to i_max. x_qh is unused.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_qh);

    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset <  nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k <  WARP_SIZE);

    const int kbx  = k / QI4_K; // == 0 if QK_K == 256
    const int kqsx = k % QI4_K; // == k if QK_K == 256

    const block_q4_K * bx0 = (const block_q4_K *) vx;

    // Pass 1: quants.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;

        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }

    // Pass 2: dm. The row index wraps modulo mmq_y so every write stays inside the tile.
    const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;

#if QK_K == 256
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
#else
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
#endif
    }

    // Pass 3: scales. Each thread produces one int (4 bytes) of the unpacked arrangement.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);

        // NOTE(review): reads the packed scale bytes as ints; assumes bxi->scales is 4-byte aligned.
        const int * scales = (const int *) bxi->scales;

        const int ksc = k % (WARP_SIZE/8);

        // Scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits

        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
// Tile-level q4_K x q8_1 dot product for x-tile row i, y-tile column j, position k.
// Resolves the padded tile offsets and forwards to vec_dot_q4_K_q8_1_impl_mmq.
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_qh);

    // row-major offsets into the padded x tiles
    const int ql_off = i * (WARP_SIZE + 1) + k;
    const int dm_off = i * (WARP_SIZE/QI4_K) + i/QI4_K;
    const int sc_off = i * (WARP_SIZE/8) + i/8 + k/16;

    // offset into the y tile for column j
    const int y_off = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;

    // byte-level view of the packed scales; the second group starts 8 bytes later
    const uint8_t * sc_lo = reinterpret_cast<const uint8_t *>(x_sc + sc_off) + 2*((k % 16) / 8);
    const uint8_t * sc_hi = sc_lo + 8;

    return vec_dot_q4_K_q8_1_impl_mmq(x_ql + ql_off, y_qs + y_off, sc_lo, sc_hi, x_dm[dm_off], y_ds + y_off/QI8_1);
}
// Reserve the __shared__ x tiles for the q5_K mul_mat_q path and publish their addresses.
// The quant tile is twice as wide as q4_K's; *x_qh is never assigned.
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    // quantized values: mmq_y rows of 2*WARP_SIZE ints, each row followed by one padding int
    __shared__ int s_ql[mmq_y * (2*WARP_SIZE + 1)];
    *x_ql = s_ql;

    // per-block d/min pairs
    __shared__ half2 s_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K];
    *x_dm = s_dm;

    // packed scales
    __shared__ int s_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8];
    *x_sc = s_sc;

    GGML_UNUSED(x_qh);
}
// Copy the q5_K data for one tile of mmq_y rows of x into the tiles produced by
// allocate_tiles_q5_K. Each stored quant byte combines 4 low bits from qs with one bit
// from qh. Per-block dm goes to x_dm and rearranged 6-bit scales/mins go to x_sc.
//   vx             - first block of the current tile
//   i_offset, k    - this thread's row offset and column within the tile (mul_mat_q passes threadIdx.y / threadIdx.x)
//   i_max          - last valid row index; rows are clamped to it when need_check is true
//   blocks_per_row - number of q5_K blocks per row of x
// x_qh is unused for this type.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_qh);

    // value-range hints for the compiler
    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI5_K; // == 0 if QK_K == 256
    const int kqsx = k % QI5_K; // == k if QK_K == 256

    const block_q5_K * bx0 = (const block_q5_K *) vx;

    // Quants: each thread reads one int of qs and writes two ints (low and high nibbles)
    // into a row of stride 2*WARP_SIZE + 1.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            // rows past the end re-read the last valid row instead of reading out of bounds
            i = min(i, i_max);
        }

        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
        const int ky = QR5_K*kqsx;

        // low nibbles and high nibbles of the 4 packed bytes
        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        // one qh bit per byte, moved to bit 4 (mask 0x10) to become the fifth bit
        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;

        // destination columns for the low-nibble and high-nibble ints
        const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
        const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);

        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256

    // Block d/min values: one half2 per block; "% mmq_y" keeps the row index inside the tile.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;

        // NOTE(review): when QK_K != 256 this loop stores nothing into x_dm; confirm that configuration is unsupported here
#if QK_K == 256
        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
#endif
    }

    // Scales/mins: WARP_SIZE/8 ints per row; thread k writes int ksc of row i.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);

        const int * scales = (const int *) bxi->scales;

        const int ksc = k % (WARP_SIZE/8);

        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits

        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
// Tile-level q5_K x q8_1 dot product for x-tile row i, y-tile column j, position k.
// Resolves the padded tile offsets and forwards to vec_dot_q5_K_q8_1_impl_mmq.
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_qh);

    // the quant tile row holds QR5_K*WARP_SIZE ints plus one padding int
    const int ql_off = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
    const int dm_off = i * (WARP_SIZE/QI5_K) + i/QI5_K;
    const int sc_off = i * (WARP_SIZE/8) + i/8 + k/16;
    const int y_off  = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE;

    // byte-level view of the packed scales; the second group starts 8 bytes later
    const uint8_t * sc_lo = reinterpret_cast<const uint8_t *>(x_sc + sc_off) + 2 * ((k % 16) / 8);
    const uint8_t * sc_hi = sc_lo + 8;

    return vec_dot_q5_K_q8_1_impl_mmq(x_ql + ql_off, y_qs + y_off, sc_lo, sc_hi, x_dm[dm_off], y_ds + y_off/QI8_1);
}
// Reserve the __shared__ x tiles for the q6_K mul_mat_q path and publish their addresses.
// *x_qh is never assigned. The dm tile is later reinterpreted as floats by the q6_K loader.
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    // quantized values: mmq_y rows of 2*WARP_SIZE ints, each row followed by one padding int
    __shared__ int s_ql[mmq_y * (2*WARP_SIZE + 1)];
    *x_ql = s_ql;

    // per-block scale storage
    __shared__ half2 s_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K];
    *x_dm = s_dm;

    // packed 8-bit sub-block scales
    __shared__ int s_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8];
    *x_sc = s_sc;

    GGML_UNUSED(x_qh);
}
// Copy the q6_K data for one tile of mmq_y rows of x into the tiles produced by
// allocate_tiles_q6_K:
//   x_ql - 6-bit quants (4 bits from ql, 2 bits from qh) shifted to signed by subtracting 32 per byte
//   x_dm - reinterpreted as float, receives one d per block
//   x_sc - the block's int8 sub-block scales, copied 4 at a time
// Parameters:
//   vx             - first block of the current tile
//   i_offset, k    - this thread's row offset and column within the tile (mul_mat_q passes threadIdx.y / threadIdx.x)
//   i_max          - last valid row index; rows are clamped to it when need_check is true
//   blocks_per_row - number of q6_K blocks per row of x
// x_qh is unused for this type.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_qh);

    // value-range hints for the compiler
    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI6_K; // == 0 if QK_K == 256
    const int kqsx = k % QI6_K; // == k if QK_K == 256

    const block_q6_K * bx0 = (const block_q6_K *) vx;

    // Quants: each thread reads one int of ql plus the matching qh bits and writes two ints
    // (low-nibble and high-nibble halves) into a row of stride 2*WARP_SIZE + 1.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            // rows past the end re-read the last valid row instead of reading out of bounds
            i = min(i, i_max);
        }

        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
        const int ky = QR6_K*kqsx;

        const int ql = get_int_from_uint8(bxi->ql, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        // two qh bits per byte, placed into bits 4-5 (mask 0x30) of each output byte
        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
        const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;

        const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
        const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);

        // per-byte saturating subtract of 32 maps the unsigned 6-bit range [0, 63] to [-32, 31]
        x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
        x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
    float * x_dmf = (float *) x_dm;

    // Block scale d: one float per block; "% mmq_y" keeps the row index inside the tile.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
    }

    // Sub-block scales: WARP_SIZE/8 ints per row; int ksc of a row comes from block ksc / (QI6_K/8).
    // The divisor mirrors the q4_K/q5_K loaders. It equals 4 when QK_K == 256, so that case is
    // unchanged, and it stays correct for other QK_K, where a fixed 4 would read every int from block 0.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI6_K/8);

        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
    }
}
// Tile-level q6_K x q8_1 dot product for x-tile row i, y-tile column j, position k.
// Unlike q4_K/q5_K, the dm and ds tiles are read as plain floats and the scales as signed bytes.
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_qh);

    const int ql_off = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
    const int y_off  = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE;
    const int sc_off = i * (WARP_SIZE/8) + i/8 + k/8;
    const int d_off  = i * (WARP_SIZE/QI6_K) + i/QI6_K;

    const float  * x_d    = reinterpret_cast<const float *>(x_dm);
    const float  * y_d    = reinterpret_cast<const float *>(y_ds);
    const int8_t * scales = reinterpret_cast<const int8_t *>(x_sc + sc_off);

    return vec_dot_q6_K_q8_1_impl_mmq(x_ql + ql_off, y_qs + y_off, scales, x_d[d_off], y_d + y_off/QI8_1);
}
// Tile shape for the q4_0 mul_mat_q kernel, per target: MMQ_X (y columns per block),
// MMQ_Y (x rows per block), NWARPS (warps per block).

// HIP builds on RDNA2/RDNA3
#define MMQ_X_Q4_0_RDNA2  64
#define MMQ_Y_Q4_0_RDNA2  128
#define NWARPS_Q4_0_RDNA2 8

// other HIP builds
#define MMQ_X_Q4_0_RDNA1  64
#define MMQ_Y_Q4_0_RDNA1  64
#define NWARPS_Q4_0_RDNA1 8

// CUDA __CUDA_ARCH__ >= CC_VOLTA; a much smaller tile when CUDA_USE_TENSOR_CORES is defined
#ifdef CUDA_USE_TENSOR_CORES
#  define MMQ_X_Q4_0_AMPERE  4
#  define MMQ_Y_Q4_0_AMPERE  32
#  define NWARPS_Q4_0_AMPERE 4
#else
#  define MMQ_X_Q4_0_AMPERE  64
#  define MMQ_Y_Q4_0_AMPERE  128
#  define NWARPS_Q4_0_AMPERE 4
#endif // CUDA_USE_TENSOR_CORES

// CUDA MIN_CC_DP4A <= __CUDA_ARCH__ < CC_VOLTA
#define MMQ_X_Q4_0_PASCAL  64
#define MMQ_Y_Q4_0_PASCAL  64
#define NWARPS_Q4_0_PASCAL 8
// Generic tiled quantized matrix multiplication: dst = x * y, where x is made of block_q_t
// blocks and y of block_q8_1 blocks.
//
// Each thread block computes an mmq_y x mmq_x tile of dst:
//   rows    starting at blockIdx.x*mmq_y
//   columns starting at blockIdx.y*mmq_x
// dst is addressed as dst[col*nrows_dst + row].
// The thread indexing assumes blockDim == (WARP_SIZE, nwarps) -- TODO confirm against the launcher.
//
// Template parameters:
//   qk             - x values per block (blocks_per_row_x = ncols_x / qk)
//   qr, qi         - per-type constants: the inner loop runs qr passes, and WARP_SIZE/qi x blocks are consumed per iteration
//   need_sum       - when true the full half2 ds of each y block is staged; otherwise only its low half, converted to float
//   allocate_tiles, load_tiles, vec_dot
//                  - type-specific shared-tile allocation, x-tile loading and tile dot product
//   vdr            - step of the k loop
//
// Global offsets are computed in 64 bits. With 32-bit int, col_dst*nrows_dst overflows
// (undefined behaviour) once dst holds more than INT_MAX elements.
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
          allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void mul_mat_q(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    const int blocks_per_row_x = ncols_x / qk;
    const int blocks_per_col_y = nrows_y / QK8_1;
    const int blocks_per_warp  = WARP_SIZE / qi;

    const int & ncols_dst = ncols_y;

    const int row_dst_0 = blockIdx.x*mmq_y;
    const int & row_x_0 = row_dst_0;

    const int col_dst_0 = blockIdx.y*mmq_x;
    const int & col_y_0 = col_dst_0;

    int   * tile_x_ql = nullptr;
    half2 * tile_x_dm = nullptr;
    int   * tile_x_qh = nullptr;
    int   * tile_x_sc = nullptr;

    allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);

    __shared__ int   tile_y_qs[mmq_x * WARP_SIZE];
    __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1];

    // per-thread accumulators: mmq_y/WARP_SIZE rows x mmq_x/nwarps columns
    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};

    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {

        // stage this block's slice of x; rows beyond nrows_x are clamped by load_tiles when need_check
        load_tiles(x + (int64_t) row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
                   threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);

#pragma unroll
        for (int ir = 0; ir < qr; ++ir) {
            const int kqs = ir*WARP_SIZE + threadIdx.x;
            const int kbxd = kqs / QI8_1;

            // stage the quantized y values for this pass
#pragma unroll
            for (int i = 0; i < mmq_x; i += nwarps) {
                const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses

                const block_q8_1 * by0 = &y[(int64_t) col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];

                const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE;
                tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
            }

            // stage the y block scales (and sums when need_sum) for this pass
#pragma unroll
            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
                const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
                const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
                const int col_y_eff = min(col_y_0 + ids, ncols_y-1);

                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
                const half2 * dsi_src = &y[(int64_t) col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
                half2       * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
                if (need_sum) {
                    *dsi_dst = *dsi_src;
                } else {
                    float * dfi_dst = (float *) dsi_dst;
                    *dfi_dst = __low2float(*dsi_src);
                }
            }

            // all tiles must be written before any thread reads them
            __syncthreads();

// #pragma unroll // unrolling this loop causes too much register pressure
            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {

#pragma unroll
                for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {
                        sum[i/WARP_SIZE][j/nwarps] += vec_dot(
                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
                            threadIdx.x + i, threadIdx.y + j, k);
                    }
                }
            }

            // all reads must finish before the tiles are overwritten in the next pass
            __syncthreads();
        }
    }

    // write results; columns increase with j, so the first out-of-range column ends this thread's work
#pragma unroll
    for (int j = 0; j < mmq_x; j += nwarps) {
        const int col_dst = col_dst_0 + j + threadIdx.y;

        if (col_dst >= ncols_dst) {
            return;
        }

#pragma unroll
        for (int i = 0; i < mmq_y; i += WARP_SIZE) {
            const int row_dst = row_dst_0 + threadIdx.x + i;

            if (row_dst >= nrows_dst) {
                continue;
            }

            dst[(int64_t) col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];
        }
    }
}
// q4_0 x q8_1 mul_mat_q kernel. Selects a tile shape for the compilation target, then forwards
// to mul_mat_q. Targets matching none of the branches (e.g. __CUDA_ARCH__ below MIN_CC_DP4A)
// compile to NO_DEVICE_CODE.
template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA3) || defined(RDNA2))
    __launch_bounds__(WARP_SIZE*NWARPS_Q4_0_RDNA2, 2)
#endif // HIP on RDNA2/RDNA3
    mul_mat_q4_0(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    // 1) tile shape for this target
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#  if defined(RDNA3) || defined(RDNA2)
    const int mmq_x  = MMQ_X_Q4_0_RDNA2;
    const int mmq_y  = MMQ_Y_Q4_0_RDNA2;
    const int nwarps = NWARPS_Q4_0_RDNA2;
#  else
    const int mmq_x  = MMQ_X_Q4_0_RDNA1;
    const int mmq_y  = MMQ_Y_Q4_0_RDNA1;
    const int nwarps = NWARPS_Q4_0_RDNA1;
#  endif // defined(RDNA3) || defined(RDNA2)
#elif __CUDA_ARCH__ >= CC_VOLTA
    const int mmq_x  = MMQ_X_Q4_0_AMPERE;
    const int mmq_y  = MMQ_Y_Q4_0_AMPERE;
    const int nwarps = NWARPS_Q4_0_AMPERE;
#elif __CUDA_ARCH__ >= MIN_CC_DP4A
    const int mmq_x  = MMQ_X_Q4_0_PASCAL;
    const int mmq_y  = MMQ_Y_Q4_0_PASCAL;
    const int nwarps = NWARPS_Q4_0_PASCAL;
#endif

    // 2) run the generic kernel, or mark the target unsupported
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_VOLTA || __CUDA_ARCH__ >= MIN_CC_DP4A
    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
        load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(vec_dot_q4_0_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif
}
// Tile shape for the q4_1 mul_mat_q kernel, per target: MMQ_X (y columns per block),
// MMQ_Y (x rows per block), NWARPS (warps per block).

// HIP builds on RDNA2/RDNA3
#define MMQ_X_Q4_1_RDNA2  64
#define MMQ_Y_Q4_1_RDNA2  128
#define NWARPS_Q4_1_RDNA2 8

// other HIP builds
#define MMQ_X_Q4_1_RDNA1  64
#define MMQ_Y_Q4_1_RDNA1  64
#define NWARPS_Q4_1_RDNA1 8

// CUDA __CUDA_ARCH__ >= CC_VOLTA; a much smaller tile when CUDA_USE_TENSOR_CORES is defined
#ifdef CUDA_USE_TENSOR_CORES
#  define MMQ_X_Q4_1_AMPERE  4
#  define MMQ_Y_Q4_1_AMPERE  32
#  define NWARPS_Q4_1_AMPERE 4
#else
#  define MMQ_X_Q4_1_AMPERE  64
#  define MMQ_Y_Q4_1_AMPERE  128
#  define NWARPS_Q4_1_AMPERE 4
#endif // CUDA_USE_TENSOR_CORES

// CUDA MIN_CC_DP4A <= __CUDA_ARCH__ < CC_VOLTA
#define MMQ_X_Q4_1_PASCAL  64
#define MMQ_Y_Q4_1_PASCAL  64
#define NWARPS_Q4_1_PASCAL 8

// q4_1 x q8_1 mul_mat_q kernel. Selects a tile shape for the compilation target, then forwards
// to mul_mat_q. Unsupported targets compile to NO_DEVICE_CODE.
template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#  if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2)
#  endif // defined(RDNA3) || defined(RDNA2)
#elif __CUDA_ARCH__ < CC_VOLTA
    __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2)
#endif // __CUDA_ARCH__ < CC_VOLTA
    mul_mat_q4_1(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    // 1) tile shape for this target
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#  if defined(RDNA3) || defined(RDNA2)
    const int mmq_x  = MMQ_X_Q4_1_RDNA2;
    const int mmq_y  = MMQ_Y_Q4_1_RDNA2;
    const int nwarps = NWARPS_Q4_1_RDNA2;
#  else
    const int mmq_x  = MMQ_X_Q4_1_RDNA1;
    const int mmq_y  = MMQ_Y_Q4_1_RDNA1;
    const int nwarps = NWARPS_Q4_1_RDNA1;
#  endif // defined(RDNA3) || defined(RDNA2)
#elif __CUDA_ARCH__ >= CC_VOLTA
    const int mmq_x  = MMQ_X_Q4_1_AMPERE;
    const int mmq_y  = MMQ_Y_Q4_1_AMPERE;
    const int nwarps = NWARPS_Q4_1_AMPERE;
#elif __CUDA_ARCH__ >= MIN_CC_DP4A
    const int mmq_x  = MMQ_X_Q4_1_PASCAL;
    const int mmq_y  = MMQ_Y_Q4_1_PASCAL;
    const int nwarps = NWARPS_Q4_1_PASCAL;
#endif

    // 2) run the generic kernel, or mark the target unsupported
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_VOLTA || __CUDA_ARCH__ >= MIN_CC_DP4A
    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
        load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(vec_dot_q4_1_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif
}
// Tile shape for the q5_0 mul_mat_q kernel, per target: MMQ_X (y columns per block),
// MMQ_Y (x rows per block), NWARPS (warps per block).

// HIP builds on RDNA2/RDNA3
#define MMQ_X_Q5_0_RDNA2  64
#define MMQ_Y_Q5_0_RDNA2  128
#define NWARPS_Q5_0_RDNA2 8

// other HIP builds
#define MMQ_X_Q5_0_RDNA1  64
#define MMQ_Y_Q5_0_RDNA1  64
#define NWARPS_Q5_0_RDNA1 8

// CUDA __CUDA_ARCH__ >= CC_VOLTA; a much smaller tile when CUDA_USE_TENSOR_CORES is defined
#ifdef CUDA_USE_TENSOR_CORES
#  define MMQ_X_Q5_0_AMPERE  4
#  define MMQ_Y_Q5_0_AMPERE  32
#  define NWARPS_Q5_0_AMPERE 4
#else
#  define MMQ_X_Q5_0_AMPERE  128
#  define MMQ_Y_Q5_0_AMPERE  64
#  define NWARPS_Q5_0_AMPERE 4
#endif // CUDA_USE_TENSOR_CORES

// CUDA MIN_CC_DP4A <= __CUDA_ARCH__ < CC_VOLTA
#define MMQ_X_Q5_0_PASCAL  64
#define MMQ_Y_Q5_0_PASCAL  64
#define NWARPS_Q5_0_PASCAL 8

// q5_0 x q8_1 mul_mat_q kernel. Selects a tile shape for the compilation target, then forwards
// to mul_mat_q. Unsupported targets compile to NO_DEVICE_CODE.
template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA3) || defined(RDNA2))
    __launch_bounds__(WARP_SIZE*NWARPS_Q5_0_RDNA2, 2)
#endif // HIP on RDNA2/RDNA3
    mul_mat_q5_0(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    // 1) tile shape for this target
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#  if defined(RDNA3) || defined(RDNA2)
    const int mmq_x  = MMQ_X_Q5_0_RDNA2;
    const int mmq_y  = MMQ_Y_Q5_0_RDNA2;
    const int nwarps = NWARPS_Q5_0_RDNA2;
#  else
    const int mmq_x  = MMQ_X_Q5_0_RDNA1;
    const int mmq_y  = MMQ_Y_Q5_0_RDNA1;
    const int nwarps = NWARPS_Q5_0_RDNA1;
#  endif // defined(RDNA3) || defined(RDNA2)
#elif __CUDA_ARCH__ >= CC_VOLTA
    const int mmq_x  = MMQ_X_Q5_0_AMPERE;
    const int mmq_y  = MMQ_Y_Q5_0_AMPERE;
    const int nwarps = NWARPS_Q5_0_AMPERE;
#elif __CUDA_ARCH__ >= MIN_CC_DP4A
    const int mmq_x  = MMQ_X_Q5_0_PASCAL;
    const int mmq_y  = MMQ_Y_Q5_0_PASCAL;
    const int nwarps = NWARPS_Q5_0_PASCAL;
#endif

    // 2) run the generic kernel, or mark the target unsupported
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_VOLTA || __CUDA_ARCH__ >= MIN_CC_DP4A
    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
        load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(vec_dot_q5_0_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif
}
// Tile shape for the q5_1 mul_mat_q kernel, per target: MMQ_X (y columns per block),
// MMQ_Y (x rows per block), NWARPS (warps per block).

// HIP builds on RDNA2/RDNA3
#define MMQ_X_Q5_1_RDNA2  64
#define MMQ_Y_Q5_1_RDNA2  128
#define NWARPS_Q5_1_RDNA2 8

// other HIP builds
#define MMQ_X_Q5_1_RDNA1  64
#define MMQ_Y_Q5_1_RDNA1  64
#define NWARPS_Q5_1_RDNA1 8

// CUDA __CUDA_ARCH__ >= CC_VOLTA; a much smaller tile when CUDA_USE_TENSOR_CORES is defined
#ifdef CUDA_USE_TENSOR_CORES
#  define MMQ_X_Q5_1_AMPERE  4
#  define MMQ_Y_Q5_1_AMPERE  32
#  define NWARPS_Q5_1_AMPERE 4
#else
#  define MMQ_X_Q5_1_AMPERE  128
#  define MMQ_Y_Q5_1_AMPERE  64
#  define NWARPS_Q5_1_AMPERE 4
#endif // CUDA_USE_TENSOR_CORES

// CUDA MIN_CC_DP4A <= __CUDA_ARCH__ < CC_VOLTA
#define MMQ_X_Q5_1_PASCAL  64
#define MMQ_Y_Q5_1_PASCAL  64
#define NWARPS_Q5_1_PASCAL 8

// q5_1 x q8_1 mul_mat_q kernel. Selects a tile shape for the compilation target, then forwards
// to mul_mat_q. Unsupported targets compile to NO_DEVICE_CODE.
template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA3) || defined(RDNA2))
    __launch_bounds__(WARP_SIZE*NWARPS_Q5_1_RDNA2, 2)
#endif // HIP on RDNA2/RDNA3
    mul_mat_q5_1(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    // 1) tile shape for this target
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#  if defined(RDNA3) || defined(RDNA2)
    const int mmq_x  = MMQ_X_Q5_1_RDNA2;
    const int mmq_y  = MMQ_Y_Q5_1_RDNA2;
    const int nwarps = NWARPS_Q5_1_RDNA2;
#  else
    const int mmq_x  = MMQ_X_Q5_1_RDNA1;
    const int mmq_y  = MMQ_Y_Q5_1_RDNA1;
    const int nwarps = NWARPS_Q5_1_RDNA1;
#  endif // defined(RDNA3) || defined(RDNA2)
#elif __CUDA_ARCH__ >= CC_VOLTA
    const int mmq_x  = MMQ_X_Q5_1_AMPERE;
    const int mmq_y  = MMQ_Y_Q5_1_AMPERE;
    const int nwarps = NWARPS_Q5_1_AMPERE;
#elif __CUDA_ARCH__ >= MIN_CC_DP4A
    const int mmq_x  = MMQ_X_Q5_1_PASCAL;
    const int mmq_y  = MMQ_Y_Q5_1_PASCAL;
    const int nwarps = NWARPS_Q5_1_PASCAL;
#endif

    // 2) run the generic kernel, or mark the target unsupported
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_VOLTA || __CUDA_ARCH__ >= MIN_CC_DP4A
    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
        load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(vec_dot_q5_1_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif
}
// Tile shape for the q8_0 mul_mat_q kernel, per target: MMQ_X (y columns per block),
// MMQ_Y (x rows per block), NWARPS (warps per block).

// HIP builds on RDNA2/RDNA3
#define MMQ_X_Q8_0_RDNA2  64
#define MMQ_Y_Q8_0_RDNA2  128
#define NWARPS_Q8_0_RDNA2 8

// other HIP builds
#define MMQ_X_Q8_0_RDNA1  64
#define MMQ_Y_Q8_0_RDNA1  64
#define NWARPS_Q8_0_RDNA1 8

// CUDA __CUDA_ARCH__ >= CC_VOLTA; a much smaller tile when CUDA_USE_TENSOR_CORES is defined
#ifdef CUDA_USE_TENSOR_CORES
#  define MMQ_X_Q8_0_AMPERE  4
#  define MMQ_Y_Q8_0_AMPERE  32
#  define NWARPS_Q8_0_AMPERE 4
#else
#  define MMQ_X_Q8_0_AMPERE  128
#  define MMQ_Y_Q8_0_AMPERE  64
#  define NWARPS_Q8_0_AMPERE 4
#endif // CUDA_USE_TENSOR_CORES

// CUDA MIN_CC_DP4A <= __CUDA_ARCH__ < CC_VOLTA
#define MMQ_X_Q8_0_PASCAL  64
#define MMQ_Y_Q8_0_PASCAL  64
#define NWARPS_Q8_0_PASCAL 8

// q8_0 x q8_1 mul_mat_q kernel. Selects a tile shape for the compilation target, then forwards
// to mul_mat_q. Unsupported targets compile to NO_DEVICE_CODE.
template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA3) || defined(RDNA2))
    __launch_bounds__(WARP_SIZE*NWARPS_Q8_0_RDNA2, 2)
#endif // HIP on RDNA2/RDNA3
    mul_mat_q8_0(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    // 1) tile shape for this target
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#  if defined(RDNA3) || defined(RDNA2)
    const int mmq_x  = MMQ_X_Q8_0_RDNA2;
    const int mmq_y  = MMQ_Y_Q8_0_RDNA2;
    const int nwarps = NWARPS_Q8_0_RDNA2;
#  else
    const int mmq_x  = MMQ_X_Q8_0_RDNA1;
    const int mmq_y  = MMQ_Y_Q8_0_RDNA1;
    const int nwarps = NWARPS_Q8_0_RDNA1;
#  endif // defined(RDNA3) || defined(RDNA2)
#elif __CUDA_ARCH__ >= CC_VOLTA
    const int mmq_x  = MMQ_X_Q8_0_AMPERE;
    const int mmq_y  = MMQ_Y_Q8_0_AMPERE;
    const int nwarps = NWARPS_Q8_0_AMPERE;
#elif __CUDA_ARCH__ >= MIN_CC_DP4A
    const int mmq_x  = MMQ_X_Q8_0_PASCAL;
    const int mmq_y  = MMQ_Y_Q8_0_PASCAL;
    const int nwarps = NWARPS_Q8_0_PASCAL;
#endif

    // 2) run the generic kernel, or mark the target unsupported
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_VOLTA || __CUDA_ARCH__ >= MIN_CC_DP4A
    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
        load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(vec_dot_q8_0_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif
}
// Tile shape for the q2_K mul_mat_q kernel, per target: MMQ_X (y columns per block),
// MMQ_Y (x rows per block), NWARPS (warps per block).

// HIP builds on RDNA2/RDNA3
#define MMQ_X_Q2_K_RDNA2  64
#define MMQ_Y_Q2_K_RDNA2  128
#define NWARPS_Q2_K_RDNA2 8

// other HIP builds
#define MMQ_X_Q2_K_RDNA1  128
#define MMQ_Y_Q2_K_RDNA1  32
#define NWARPS_Q2_K_RDNA1 8

// CUDA __CUDA_ARCH__ >= CC_VOLTA; a much smaller tile when CUDA_USE_TENSOR_CORES is defined
#ifdef CUDA_USE_TENSOR_CORES
#  define MMQ_X_Q2_K_AMPERE  4
#  define MMQ_Y_Q2_K_AMPERE  32
#  define NWARPS_Q2_K_AMPERE 4
#else
#  define MMQ_X_Q2_K_AMPERE  64
#  define MMQ_Y_Q2_K_AMPERE  128
#  define NWARPS_Q2_K_AMPERE 4
#endif // CUDA_USE_TENSOR_CORES

// CUDA MIN_CC_DP4A <= __CUDA_ARCH__ < CC_VOLTA
#define MMQ_X_Q2_K_PASCAL  64
#define MMQ_Y_Q2_K_PASCAL  64
#define NWARPS_Q2_K_PASCAL 8

// q2_K x q8_1 mul_mat_q kernel. Selects a tile shape for the compilation target, then forwards
// to mul_mat_q. Unsupported targets compile to NO_DEVICE_CODE.
template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA3) || defined(RDNA2))
    __launch_bounds__(WARP_SIZE*NWARPS_Q2_K_RDNA2, 2)
#endif // HIP on RDNA2/RDNA3
    mul_mat_q2_K(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    // 1) tile shape for this target
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#  if defined(RDNA3) || defined(RDNA2)
    const int mmq_x  = MMQ_X_Q2_K_RDNA2;
    const int mmq_y  = MMQ_Y_Q2_K_RDNA2;
    const int nwarps = NWARPS_Q2_K_RDNA2;
#  else
    const int mmq_x  = MMQ_X_Q2_K_RDNA1;
    const int mmq_y  = MMQ_Y_Q2_K_RDNA1;
    const int nwarps = NWARPS_Q2_K_RDNA1;
#  endif // defined(RDNA3) || defined(RDNA2)
#elif __CUDA_ARCH__ >= CC_VOLTA
    const int mmq_x  = MMQ_X_Q2_K_AMPERE;
    const int mmq_y  = MMQ_Y_Q2_K_AMPERE;
    const int nwarps = NWARPS_Q2_K_AMPERE;
#elif __CUDA_ARCH__ >= MIN_CC_DP4A
    const int mmq_x  = MMQ_X_Q2_K_PASCAL;
    const int mmq_y  = MMQ_Y_Q2_K_PASCAL;
    const int nwarps = NWARPS_Q2_K_PASCAL;
#endif

    // 2) run the generic kernel, or mark the target unsupported
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_VOLTA || __CUDA_ARCH__ >= MIN_CC_DP4A
    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
        load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(vec_dot_q2_K_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif
}
// Tile shape for the q3_K mul_mat_q kernel, per target: MMQ_X (y columns per block),
// MMQ_Y (x rows per block), NWARPS (warps per block).

// HIP builds on RDNA2/RDNA3
#define MMQ_X_Q3_K_RDNA2  128
#define MMQ_Y_Q3_K_RDNA2  64
#define NWARPS_Q3_K_RDNA2 8

// other HIP builds
#define MMQ_X_Q3_K_RDNA1  32
#define MMQ_Y_Q3_K_RDNA1  128
#define NWARPS_Q3_K_RDNA1 8

// CUDA __CUDA_ARCH__ >= CC_VOLTA; a much smaller tile when CUDA_USE_TENSOR_CORES is defined
#ifdef CUDA_USE_TENSOR_CORES
#  define MMQ_X_Q3_K_AMPERE  4
#  define MMQ_Y_Q3_K_AMPERE  32
#  define NWARPS_Q3_K_AMPERE 4
#else
#  define MMQ_X_Q3_K_AMPERE  128
#  define MMQ_Y_Q3_K_AMPERE  128
#  define NWARPS_Q3_K_AMPERE 4
#endif // CUDA_USE_TENSOR_CORES

// CUDA MIN_CC_DP4A <= __CUDA_ARCH__ < CC_VOLTA
#define MMQ_X_Q3_K_PASCAL  64
#define MMQ_Y_Q3_K_PASCAL  64
#define NWARPS_Q3_K_PASCAL 8

// q3_K x q8_1 mul_mat_q kernel. Selects a tile shape for the compilation target, then forwards
// to mul_mat_q. Unsupported targets compile to NO_DEVICE_CODE.
template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#  if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2)
#  endif // defined(RDNA3) || defined(RDNA2)
#elif __CUDA_ARCH__ < CC_VOLTA
    __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2)
#endif // __CUDA_ARCH__ < CC_VOLTA
    mul_mat_q3_K(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    // 1) tile shape for this target
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#  if defined(RDNA3) || defined(RDNA2)
    const int mmq_x  = MMQ_X_Q3_K_RDNA2;
    const int mmq_y  = MMQ_Y_Q3_K_RDNA2;
    const int nwarps = NWARPS_Q3_K_RDNA2;
#  else
    const int mmq_x  = MMQ_X_Q3_K_RDNA1;
    const int mmq_y  = MMQ_Y_Q3_K_RDNA1;
    const int nwarps = NWARPS_Q3_K_RDNA1;
#  endif // defined(RDNA3) || defined(RDNA2)
#elif __CUDA_ARCH__ >= CC_VOLTA
    const int mmq_x  = MMQ_X_Q3_K_AMPERE;
    const int mmq_y  = MMQ_Y_Q3_K_AMPERE;
    const int nwarps = NWARPS_Q3_K_AMPERE;
#elif __CUDA_ARCH__ >= MIN_CC_DP4A
    const int mmq_x  = MMQ_X_Q3_K_PASCAL;
    const int mmq_y  = MMQ_Y_Q3_K_PASCAL;
    const int nwarps = NWARPS_Q3_K_PASCAL;
#endif

    // 2) run the generic kernel, or mark the target unsupported
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_VOLTA || __CUDA_ARCH__ >= MIN_CC_DP4A
    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
        load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(vec_dot_q3_K_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif
}
// Tile shape for the q4_K mul_mat_q kernel, per target: MMQ_X (y columns per block),
// MMQ_Y (x rows per block), NWARPS (warps per block).

// HIP builds on RDNA2/RDNA3
#define MMQ_X_Q4_K_RDNA2  64
#define MMQ_Y_Q4_K_RDNA2  128
#define NWARPS_Q4_K_RDNA2 8

// other HIP builds
#define MMQ_X_Q4_K_RDNA1  32
#define MMQ_Y_Q4_K_RDNA1  64
#define NWARPS_Q4_K_RDNA1 8

// CUDA __CUDA_ARCH__ >= CC_VOLTA; a much smaller tile when CUDA_USE_TENSOR_CORES is defined
#ifdef CUDA_USE_TENSOR_CORES
#  define MMQ_X_Q4_K_AMPERE  4
#  define MMQ_Y_Q4_K_AMPERE  32
#  define NWARPS_Q4_K_AMPERE 4
#else
#  define MMQ_X_Q4_K_AMPERE  64
#  define MMQ_Y_Q4_K_AMPERE  128
#  define NWARPS_Q4_K_AMPERE 4
#endif // CUDA_USE_TENSOR_CORES

// CUDA MIN_CC_DP4A <= __CUDA_ARCH__ < CC_VOLTA
#define MMQ_X_Q4_K_PASCAL  64
#define MMQ_Y_Q4_K_PASCAL  64
#define NWARPS_Q4_K_PASCAL 8

// q4_K x q8_1 mul_mat_q kernel. Selects a tile shape for the compilation target, then forwards
// to mul_mat_q. Unsupported targets compile to NO_DEVICE_CODE.
template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#  if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2)
#  endif // defined(RDNA3) || defined(RDNA2)
#elif __CUDA_ARCH__ < CC_VOLTA
    __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2)
#endif // __CUDA_ARCH__ < CC_VOLTA
    mul_mat_q4_K(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    // 1) tile shape for this target
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#  if defined(RDNA3) || defined(RDNA2)
    const int mmq_x  = MMQ_X_Q4_K_RDNA2;
    const int mmq_y  = MMQ_Y_Q4_K_RDNA2;
    const int nwarps = NWARPS_Q4_K_RDNA2;
#  else
    const int mmq_x  = MMQ_X_Q4_K_RDNA1;
    const int mmq_y  = MMQ_Y_Q4_K_RDNA1;
    const int nwarps = NWARPS_Q4_K_RDNA1;
#  endif // defined(RDNA3) || defined(RDNA2)
#elif __CUDA_ARCH__ >= CC_VOLTA
    const int mmq_x  = MMQ_X_Q4_K_AMPERE;
    const int mmq_y  = MMQ_Y_Q4_K_AMPERE;
    const int nwarps = NWARPS_Q4_K_AMPERE;
#elif __CUDA_ARCH__ >= MIN_CC_DP4A
    const int mmq_x  = MMQ_X_Q4_K_PASCAL;
    const int mmq_y  = MMQ_Y_Q4_K_PASCAL;
    const int nwarps = NWARPS_Q4_K_PASCAL;
#endif

    // 2) run the generic kernel, or mark the target unsupported
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_VOLTA || __CUDA_ARCH__ >= MIN_CC_DP4A
    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
        load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(vec_dot_q4_K_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif
}
- #define MMQ_X_Q5_K_RDNA2 64
- #define MMQ_Y_Q5_K_RDNA2 128
- #define NWARPS_Q5_K_RDNA2 8
- #define MMQ_X_Q5_K_RDNA1 32
- #define MMQ_Y_Q5_K_RDNA1 64
- #define NWARPS_Q5_K_RDNA1 8
- #if defined(CUDA_USE_TENSOR_CORES)
- #define MMQ_X_Q5_K_AMPERE 4
- #define MMQ_Y_Q5_K_AMPERE 32
- #define NWARPS_Q5_K_AMPERE 4
- #else
- #define MMQ_X_Q5_K_AMPERE 64
- #define MMQ_Y_Q5_K_AMPERE 128
- #define NWARPS_Q5_K_AMPERE 4
- #endif
- #define MMQ_X_Q5_K_PASCAL 64
- #define MMQ_Y_Q5_K_PASCAL 64
- #define NWARPS_Q5_K_PASCAL 8
- template <bool need_check> static __global__ void
- #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
- #if defined(RDNA3) || defined(RDNA2)
- __launch_bounds__(WARP_SIZE*NWARPS_Q5_K_RDNA2, 2)
- #endif // defined(RDNA3) || defined(RDNA2)
- #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
- mul_mat_q5_K(
- const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
- const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
- #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
- #if defined(RDNA3) || defined(RDNA2)
- const int mmq_x = MMQ_X_Q5_K_RDNA2;
- const int mmq_y = MMQ_Y_Q5_K_RDNA2;
- const int nwarps = NWARPS_Q5_K_RDNA2;
- #else
- const int mmq_x = MMQ_X_Q5_K_RDNA1;
- const int mmq_y = MMQ_Y_Q5_K_RDNA1;
- const int nwarps = NWARPS_Q5_K_RDNA1;
- #endif // defined(RDNA3) || defined(RDNA2)
- mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
- load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- #elif __CUDA_ARCH__ >= CC_VOLTA
- const int mmq_x = MMQ_X_Q5_K_AMPERE;
- const int mmq_y = MMQ_Y_Q5_K_AMPERE;
- const int nwarps = NWARPS_Q5_K_AMPERE;
- mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
- load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- #elif __CUDA_ARCH__ >= MIN_CC_DP4A
- const int mmq_x = MMQ_X_Q5_K_PASCAL;
- const int mmq_y = MMQ_Y_Q5_K_PASCAL;
- const int nwarps = NWARPS_Q5_K_PASCAL;
- mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
- load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- #else
- GGML_UNUSED(vec_dot_q5_K_q8_1_mul_mat);
- NO_DEVICE_CODE;
- #endif // __CUDA_ARCH__ >= CC_VOLTA
- }
- #define MMQ_X_Q6_K_RDNA2 64
- #define MMQ_Y_Q6_K_RDNA2 128
- #define NWARPS_Q6_K_RDNA2 8
- #define MMQ_X_Q6_K_RDNA1 32
- #define MMQ_Y_Q6_K_RDNA1 64
- #define NWARPS_Q6_K_RDNA1 8
- #if defined(CUDA_USE_TENSOR_CORES)
- #define MMQ_X_Q6_K_AMPERE 4
- #define MMQ_Y_Q6_K_AMPERE 32
- #define NWARPS_Q6_K_AMPERE 4
- #else
- #define MMQ_X_Q6_K_AMPERE 64
- #define MMQ_Y_Q6_K_AMPERE 64
- #define NWARPS_Q6_K_AMPERE 4
- #endif
- #define MMQ_X_Q6_K_PASCAL 64
- #define MMQ_Y_Q6_K_PASCAL 64
- #define NWARPS_Q6_K_PASCAL 8
- template <bool need_check> static __global__ void
- #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
- #if defined(RDNA3) || defined(RDNA2)
- __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2)
- #endif // defined(RDNA3) || defined(RDNA2)
- #elif __CUDA_ARCH__ < CC_VOLTA
- __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2)
- #endif // __CUDA_ARCH__ < CC_VOLTA
- mul_mat_q6_K(
- const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
- const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
- #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
- #if defined(RDNA3) || defined(RDNA2)
- const int mmq_x = MMQ_X_Q6_K_RDNA2;
- const int mmq_y = MMQ_Y_Q6_K_RDNA2;
- const int nwarps = NWARPS_Q6_K_RDNA2;
- #else
- const int mmq_x = MMQ_X_Q6_K_RDNA1;
- const int mmq_y = MMQ_Y_Q6_K_RDNA1;
- const int nwarps = NWARPS_Q6_K_RDNA1;
- #endif // defined(RDNA3) || defined(RDNA2)
- mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
- load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- #elif __CUDA_ARCH__ >= CC_VOLTA
- const int mmq_x = MMQ_X_Q6_K_AMPERE;
- const int mmq_y = MMQ_Y_Q6_K_AMPERE;
- const int nwarps = NWARPS_Q6_K_AMPERE;
- mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
- load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- #elif __CUDA_ARCH__ >= MIN_CC_DP4A
- const int mmq_x = MMQ_X_Q6_K_PASCAL;
- const int mmq_y = MMQ_Y_Q6_K_PASCAL;
- const int nwarps = NWARPS_Q6_K_PASCAL;
- mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
- load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- #else
- GGML_UNUSED(vec_dot_q6_K_q8_1_mul_mat);
- NO_DEVICE_CODE;
- #endif // __CUDA_ARCH__ >= CC_VOLTA
- }
- static void ggml_mul_mat_q4_0_q8_1_cuda(
- const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
- const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = ggml_cuda_info().devices[id].cc;
- int mmq_x, mmq_y, nwarps;
- if (compute_capability >= CC_RDNA2) {
- mmq_x = MMQ_X_Q4_0_RDNA2;
- mmq_y = MMQ_Y_Q4_0_RDNA2;
- nwarps = NWARPS_Q4_0_RDNA2;
- } else if (compute_capability >= CC_OFFSET_AMD) {
- mmq_x = MMQ_X_Q4_0_RDNA1;
- mmq_y = MMQ_Y_Q4_0_RDNA1;
- nwarps = NWARPS_Q4_0_RDNA1;
- } else if (compute_capability >= CC_VOLTA) {
- mmq_x = MMQ_X_Q4_0_AMPERE;
- mmq_y = MMQ_Y_Q4_0_AMPERE;
- nwarps = NWARPS_Q4_0_AMPERE;
- } else if (compute_capability >= MIN_CC_DP4A) {
- mmq_x = MMQ_X_Q4_0_PASCAL;
- mmq_y = MMQ_Y_Q4_0_PASCAL;
- nwarps = NWARPS_Q4_0_PASCAL;
- } else {
- GGML_ASSERT(false);
- }
- const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
- const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
- if (nrows_x % mmq_y == 0) {
- const bool need_check = false;
- mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- } else {
- const bool need_check = true;
- mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- }
- }
- static void ggml_mul_mat_q4_1_q8_1_cuda(
- const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
- const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = ggml_cuda_info().devices[id].cc;
- int mmq_x, mmq_y, nwarps;
- if (compute_capability >= CC_RDNA2) {
- mmq_x = MMQ_X_Q4_1_RDNA2;
- mmq_y = MMQ_Y_Q4_1_RDNA2;
- nwarps = NWARPS_Q4_1_RDNA2;
- } else if (compute_capability >= CC_OFFSET_AMD) {
- mmq_x = MMQ_X_Q4_1_RDNA1;
- mmq_y = MMQ_Y_Q4_1_RDNA1;
- nwarps = NWARPS_Q4_1_RDNA1;
- } else if (compute_capability >= CC_VOLTA) {
- mmq_x = MMQ_X_Q4_1_AMPERE;
- mmq_y = MMQ_Y_Q4_1_AMPERE;
- nwarps = NWARPS_Q4_1_AMPERE;
- } else if (compute_capability >= MIN_CC_DP4A) {
- mmq_x = MMQ_X_Q4_1_PASCAL;
- mmq_y = MMQ_Y_Q4_1_PASCAL;
- nwarps = NWARPS_Q4_1_PASCAL;
- } else {
- GGML_ASSERT(false);
- }
- const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
- const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
- if (nrows_x % mmq_y == 0) {
- const bool need_check = false;
- mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- } else {
- const bool need_check = true;
- mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- }
- }
- static void ggml_mul_mat_q5_0_q8_1_cuda(
- const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
- const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = ggml_cuda_info().devices[id].cc;
- int mmq_x, mmq_y, nwarps;
- if (compute_capability >= CC_RDNA2) {
- mmq_x = MMQ_X_Q5_0_RDNA2;
- mmq_y = MMQ_Y_Q5_0_RDNA2;
- nwarps = NWARPS_Q5_0_RDNA2;
- } else if (compute_capability >= CC_OFFSET_AMD) {
- mmq_x = MMQ_X_Q5_0_RDNA1;
- mmq_y = MMQ_Y_Q5_0_RDNA1;
- nwarps = NWARPS_Q5_0_RDNA1;
- } else if (compute_capability >= CC_VOLTA) {
- mmq_x = MMQ_X_Q5_0_AMPERE;
- mmq_y = MMQ_Y_Q5_0_AMPERE;
- nwarps = NWARPS_Q5_0_AMPERE;
- } else if (compute_capability >= MIN_CC_DP4A) {
- mmq_x = MMQ_X_Q5_0_PASCAL;
- mmq_y = MMQ_Y_Q5_0_PASCAL;
- nwarps = NWARPS_Q5_0_PASCAL;
- } else {
- GGML_ASSERT(false);
- }
- const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
- const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
- if (nrows_x % mmq_y == 0) {
- const bool need_check = false;
- mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- } else {
- const bool need_check = true;
- mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- }
- }
- static void ggml_mul_mat_q5_1_q8_1_cuda(
- const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
- const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = ggml_cuda_info().devices[id].cc;
- int mmq_x, mmq_y, nwarps;
- if (compute_capability >= CC_RDNA2) {
- mmq_x = MMQ_X_Q5_1_RDNA2;
- mmq_y = MMQ_Y_Q5_1_RDNA2;
- nwarps = NWARPS_Q5_1_RDNA2;
- } else if (compute_capability >= CC_OFFSET_AMD) {
- mmq_x = MMQ_X_Q5_1_RDNA1;
- mmq_y = MMQ_Y_Q5_1_RDNA1;
- nwarps = NWARPS_Q5_1_RDNA1;
- } else if (compute_capability >= CC_VOLTA) {
- mmq_x = MMQ_X_Q5_1_AMPERE;
- mmq_y = MMQ_Y_Q5_1_AMPERE;
- nwarps = NWARPS_Q5_1_AMPERE;
- } else if (compute_capability >= MIN_CC_DP4A) {
- mmq_x = MMQ_X_Q5_1_PASCAL;
- mmq_y = MMQ_Y_Q5_1_PASCAL;
- nwarps = NWARPS_Q5_1_PASCAL;
- } else {
- GGML_ASSERT(false);
- }
- const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
- const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
- if (nrows_x % mmq_y == 0) {
- const bool need_check = false;
- mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- } else {
- const bool need_check = true;
- mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- }
- }
- static void ggml_mul_mat_q8_0_q8_1_cuda(
- const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
- const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = ggml_cuda_info().devices[id].cc;
- int mmq_x, mmq_y, nwarps;
- if (compute_capability >= CC_RDNA2) {
- mmq_x = MMQ_X_Q8_0_RDNA2;
- mmq_y = MMQ_Y_Q8_0_RDNA2;
- nwarps = NWARPS_Q8_0_RDNA2;
- } else if (compute_capability >= CC_OFFSET_AMD) {
- mmq_x = MMQ_X_Q8_0_RDNA1;
- mmq_y = MMQ_Y_Q8_0_RDNA1;
- nwarps = NWARPS_Q8_0_RDNA1;
- } else if (compute_capability >= CC_VOLTA) {
- mmq_x = MMQ_X_Q8_0_AMPERE;
- mmq_y = MMQ_Y_Q8_0_AMPERE;
- nwarps = NWARPS_Q8_0_AMPERE;
- } else if (compute_capability >= MIN_CC_DP4A) {
- mmq_x = MMQ_X_Q8_0_PASCAL;
- mmq_y = MMQ_Y_Q8_0_PASCAL;
- nwarps = NWARPS_Q8_0_PASCAL;
- } else {
- GGML_ASSERT(false);
- }
- const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
- const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
- if (nrows_x % mmq_y == 0) {
- const bool need_check = false;
- mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- } else {
- const bool need_check = true;
- mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- }
- }
- static void ggml_mul_mat_q2_K_q8_1_cuda(
- const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
- const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = ggml_cuda_info().devices[id].cc;
- int mmq_x, mmq_y, nwarps;
- if (compute_capability >= CC_RDNA2) {
- mmq_x = MMQ_X_Q2_K_RDNA2;
- mmq_y = MMQ_Y_Q2_K_RDNA2;
- nwarps = NWARPS_Q2_K_RDNA2;
- } else if (compute_capability >= CC_OFFSET_AMD) {
- mmq_x = MMQ_X_Q2_K_RDNA1;
- mmq_y = MMQ_Y_Q2_K_RDNA1;
- nwarps = NWARPS_Q2_K_RDNA1;
- } else if (compute_capability >= CC_VOLTA) {
- mmq_x = MMQ_X_Q2_K_AMPERE;
- mmq_y = MMQ_Y_Q2_K_AMPERE;
- nwarps = NWARPS_Q2_K_AMPERE;
- } else if (compute_capability >= MIN_CC_DP4A) {
- mmq_x = MMQ_X_Q2_K_PASCAL;
- mmq_y = MMQ_Y_Q2_K_PASCAL;
- nwarps = NWARPS_Q2_K_PASCAL;
- } else {
- GGML_ASSERT(false);
- }
- const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
- const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
- if (nrows_x % mmq_y == 0) {
- const bool need_check = false;
- mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- } else {
- const bool need_check = true;
- mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- }
- }
- static void ggml_mul_mat_q3_K_q8_1_cuda(
- const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
- const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- #if QK_K == 256
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = ggml_cuda_info().devices[id].cc;
- int mmq_x, mmq_y, nwarps;
- if (compute_capability >= CC_RDNA2) {
- mmq_x = MMQ_X_Q3_K_RDNA2;
- mmq_y = MMQ_Y_Q3_K_RDNA2;
- nwarps = NWARPS_Q3_K_RDNA2;
- } else if (compute_capability >= CC_OFFSET_AMD) {
- mmq_x = MMQ_X_Q3_K_RDNA1;
- mmq_y = MMQ_Y_Q3_K_RDNA1;
- nwarps = NWARPS_Q3_K_RDNA1;
- } else if (compute_capability >= CC_VOLTA) {
- mmq_x = MMQ_X_Q3_K_AMPERE;
- mmq_y = MMQ_Y_Q3_K_AMPERE;
- nwarps = NWARPS_Q3_K_AMPERE;
- } else if (compute_capability >= MIN_CC_DP4A) {
- mmq_x = MMQ_X_Q3_K_PASCAL;
- mmq_y = MMQ_Y_Q3_K_PASCAL;
- nwarps = NWARPS_Q3_K_PASCAL;
- } else {
- GGML_ASSERT(false);
- }
- const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
- const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
- if (nrows_x % mmq_y == 0) {
- const bool need_check = false;
- mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- } else {
- const bool need_check = true;
- mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- }
- #endif
- }
- static void ggml_mul_mat_q4_K_q8_1_cuda(
- const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
- const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = ggml_cuda_info().devices[id].cc;
- int mmq_x, mmq_y, nwarps;
- if (compute_capability >= CC_RDNA2) {
- mmq_x = MMQ_X_Q4_K_RDNA2;
- mmq_y = MMQ_Y_Q4_K_RDNA2;
- nwarps = NWARPS_Q4_K_RDNA2;
- } else if (compute_capability >= CC_OFFSET_AMD) {
- mmq_x = MMQ_X_Q4_K_RDNA1;
- mmq_y = MMQ_Y_Q4_K_RDNA1;
- nwarps = NWARPS_Q4_K_RDNA1;
- } else if (compute_capability >= CC_VOLTA) {
- mmq_x = MMQ_X_Q4_K_AMPERE;
- mmq_y = MMQ_Y_Q4_K_AMPERE;
- nwarps = NWARPS_Q4_K_AMPERE;
- } else if (compute_capability >= MIN_CC_DP4A) {
- mmq_x = MMQ_X_Q4_K_PASCAL;
- mmq_y = MMQ_Y_Q4_K_PASCAL;
- nwarps = NWARPS_Q4_K_PASCAL;
- } else {
- GGML_ASSERT(false);
- }
- const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
- const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
- if (nrows_x % mmq_y == 0) {
- const bool need_check = false;
- mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- } else {
- const bool need_check = true;
- mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- }
- }
- static void ggml_mul_mat_q5_K_q8_1_cuda(
- const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
- const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = ggml_cuda_info().devices[id].cc;
- int mmq_x, mmq_y, nwarps;
- if (compute_capability >= CC_RDNA2) {
- mmq_x = MMQ_X_Q5_K_RDNA2;
- mmq_y = MMQ_Y_Q5_K_RDNA2;
- nwarps = NWARPS_Q5_K_RDNA2;
- } else if (compute_capability >= CC_OFFSET_AMD) {
- mmq_x = MMQ_X_Q5_K_RDNA1;
- mmq_y = MMQ_Y_Q5_K_RDNA1;
- nwarps = NWARPS_Q5_K_RDNA1;
- } else if (compute_capability >= CC_VOLTA) {
- mmq_x = MMQ_X_Q5_K_AMPERE;
- mmq_y = MMQ_Y_Q5_K_AMPERE;
- nwarps = NWARPS_Q5_K_AMPERE;
- } else if (compute_capability >= MIN_CC_DP4A) {
- mmq_x = MMQ_X_Q5_K_PASCAL;
- mmq_y = MMQ_Y_Q5_K_PASCAL;
- nwarps = NWARPS_Q5_K_PASCAL;
- } else {
- GGML_ASSERT(false);
- }
- const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
- const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
- if (nrows_x % mmq_y == 0) {
- const bool need_check = false;
- mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- } else {
- const bool need_check = true;
- mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- }
- }
- static void ggml_mul_mat_q6_K_q8_1_cuda(
- const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
- const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
- const int compute_capability = ggml_cuda_info().devices[id].cc;
- int mmq_x, mmq_y, nwarps;
- if (compute_capability >= CC_RDNA2) {
- mmq_x = MMQ_X_Q6_K_RDNA2;
- mmq_y = MMQ_Y_Q6_K_RDNA2;
- nwarps = NWARPS_Q6_K_RDNA2;
- } else if (compute_capability >= CC_OFFSET_AMD) {
- mmq_x = MMQ_X_Q6_K_RDNA1;
- mmq_y = MMQ_Y_Q6_K_RDNA1;
- nwarps = NWARPS_Q6_K_RDNA1;
- } else if (compute_capability >= CC_VOLTA) {
- mmq_x = MMQ_X_Q6_K_AMPERE;
- mmq_y = MMQ_Y_Q6_K_AMPERE;
- nwarps = NWARPS_Q6_K_AMPERE;
- } else if (compute_capability >= MIN_CC_DP4A) {
- mmq_x = MMQ_X_Q6_K_PASCAL;
- mmq_y = MMQ_Y_Q6_K_PASCAL;
- nwarps = NWARPS_Q6_K_PASCAL;
- } else {
- GGML_ASSERT(false);
- }
- const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
- const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
- if (nrows_x % mmq_y == 0) {
- const bool need_check = false;
- mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- } else {
- const bool need_check = true;
- mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
- (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
- }
- }
- void ggml_cuda_op_mul_mat_q(
- ggml_backend_cuda_context & ctx,
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
- const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
- const int64_t src1_padded_row_size, cudaStream_t stream) {
- const int64_t ne00 = src0->ne[0];
- const int64_t ne10 = src1->ne[0];
- GGML_ASSERT(ne10 % QK8_1 == 0);
- const int64_t ne0 = dst->ne[0];
- const int64_t row_diff = row_high - row_low;
- int id = ggml_cuda_get_device();
- // the main device has a larger memory buffer to hold the results from all GPUs
- // nrows_dst == nrows of the matrix that the kernel writes into
- const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
- switch (src0->type) {
- case GGML_TYPE_Q4_0:
- ggml_mul_mat_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
- break;
- case GGML_TYPE_Q4_1:
- ggml_mul_mat_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
- break;
- case GGML_TYPE_Q5_0:
- ggml_mul_mat_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
- break;
- case GGML_TYPE_Q5_1:
- ggml_mul_mat_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
- break;
- case GGML_TYPE_Q8_0:
- ggml_mul_mat_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
- break;
- case GGML_TYPE_Q2_K:
- ggml_mul_mat_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
- break;
- case GGML_TYPE_Q3_K:
- ggml_mul_mat_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
- break;
- case GGML_TYPE_Q4_K:
- ggml_mul_mat_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
- break;
- case GGML_TYPE_Q5_K:
- ggml_mul_mat_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
- break;
- case GGML_TYPE_Q6_K:
- ggml_mul_mat_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
- break;
- default:
- GGML_ASSERT(false);
- break;
- }
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_ddf_i);
- }
- bool ggml_cuda_supports_mmq(enum ggml_type type) {
- switch (type) {
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- return true;
- default:
- return false;
- }
- }
|