ggml-cuda.cu 113 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050
  1. #include <cstddef>
  2. #include <cstdint>
  3. #include <limits>
  4. #include <stdint.h>
  5. #include <stdio.h>
  6. #include <atomic>
  7. #include <assert.h>
  8. #include <cuda_runtime.h>
  9. #include <cublas_v2.h>
  10. #include <cuda_fp16.h>
  11. #include "ggml-cuda.h"
  12. #include "ggml.h"
  // MSVC: silence C4244 / C4267 ("possible loss of data" on implicit narrowing).
  #ifdef _MSC_VER
  #pragma warning(disable: 4244 4267)
  #endif

  // The quantized block structs in this file store fp16 fields as CUDA `half` while
  // their size checks are written in terms of ggml_fp16_t, so the two must match in size.
  static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
  // CUDA_CHECK(expr): evaluate a CUDA runtime call exactly once. On any status other
  // than cudaSuccess, print the numeric code, the call site and cudaGetErrorString()
  // to stderr, then terminate the process with exit(1).
  #define CUDA_CHECK(err)                                                        \
      do {                                                                       \
          const cudaError_t cuda_status_ = (err);                                \
          if (cuda_status_ == cudaSuccess) {                                     \
              break;                                                             \
          }                                                                      \
          fprintf(stderr, "CUDA error %d at %s:%d: %s\n", cuda_status_,          \
                  __FILE__, __LINE__, cudaGetErrorString(cuda_status_));         \
          exit(1);                                                               \
      } while (0)
  // CUBLAS_CHECK(expr): evaluate a cuBLAS call exactly once. If it does not return
  // CUBLAS_STATUS_SUCCESS, print the numeric status and the call site to stderr and
  // terminate with exit(1).
  // The human-readable cublasGetStatusString() text is only added when building
  // against CUDA runtime >= 12.0; older toolkits report the numeric code alone.
  #if CUDART_VERSION >= 12000
  #define CUBLAS_CHECK(err) \
      do { \
          cublasStatus_t err_ = (err); \
          if (err_ != CUBLAS_STATUS_SUCCESS) { \
              fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
                      err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
              exit(1); \
          } \
      } while (0)
  #else
  #define CUBLAS_CHECK(err) \
      do { \
          cublasStatus_t err_ = (err); \
          if (err_ != CUBLAS_STATUS_SUCCESS) { \
              fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
              exit(1); \
          } \
      } while (0)
  #endif // CUDART_VERSION >= 12000
  // dfloat / dfloat2: the arithmetic types used by the dequantize_q* helpers.
  // Defining GGML_CUDA_DMMV_F16 switches them to half / half2; otherwise they are fp32.
  #ifdef GGML_CUDA_DMMV_F16
  using dfloat  = half;    // dequantize float
  using dfloat2 = half2;
  #else
  using dfloat  = float;   // dequantize float
  using dfloat2 = float2;
  #endif // GGML_CUDA_DMMV_F16
  // Function-pointer type aliases.

  // Dequantizes the two values addressed by (block ib, quant index iqs) of vx into v;
  // matches the signature of the dequantize_q* device helpers.
  using dequantize_kernel_t = void (*)(const void * vx, const int ib, const int iqs, dfloat2 & v);

  // Converts k elements of x to fp32 in y, enqueued on `stream` (presumably a host-side launcher).
  using to_fp32_cuda_t = void (*)(const void * x, float * y, int k, cudaStream_t stream);

  // Presumably a fused dequantize + dot-product step for block ib against y,
  // producing v — TODO confirm against its implementations.
  using dot_kernel_k_t = void (*)(const void * vx, const int ib, const int iqs, const float * y, float & v);

  // Copies/converts one element from cx to cdst (presumably; implementations not visible here).
  using cpy_kernel_t = void (*)(const char * cx, char * cdst);

  // Computes dst from src0 / src1.
  using ggml_cuda_func_t = void (*)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);

  // Per-slice operation callback. The *_dd*_i buffers and the i02 / [i01_low, i01_high) /
  // i1 arguments presumably select the device data and row range being processed — TODO confirm.
  using ggml_cuda_op_t = void (*)(
      const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i,
      float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
      cudaStream_t & cudaStream_main);
  // Quantization block layouts. Each struct's size is pinned with a static_assert so no
  // padding can creep in; presumably the layouts must match ggml's CPU-side block
  // formats byte-for-byte (TODO confirm against ggml.c).
  62. // QK = number of values after dequantization (weights per block)
  63. // QR = QK / number of values before dequantization, i.e. quants packed per stored qs byte (2 for 4/5-bit, 1 for 8-bit)
  // q4_0: 32 weights per block as 4-bit quants (two per qs byte) plus one fp16 scale.
  // dequantize_q4_0 reconstructs each weight as (q - 8) * d.
  64. #define QK4_0 32
  65. #define QR4_0 2
  66. typedef struct {
  67. half d; // delta: per-block scale
  68. uint8_t qs[QK4_0 / 2]; // nibbles / quants: low and high nibble of each byte are separate quants
  69. } block_q4_0;
  70. static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
  // q4_1: 32 weights per block as 4-bit quants (two per qs byte) with a per-block scale
  // and minimum; dequantize_q4_1 reconstructs each weight as q * d + m.
  71. #define QK4_1 32
  72. #define QR4_1 2
  73. typedef struct {
  74. half d; // delta: per-block scale
  75. half m; // min: per-block offset added after scaling
  76. uint8_t qs[QK4_1 / 2]; // nibbles / quants: low and high nibble of each byte are separate quants
  77. } block_q4_1;
  78. static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
  // q5_0: 32 weights per block as 5-bit quants. The low 4 bits are packed two per qs
  // byte; the 5th bits of all 32 quants form the 32-bit mask qh. For byte iqs,
  // dequantize_q5_0 takes bit iqs for the low-nibble quant and bit iqs + 16 for the
  // high-nibble quant, and reconstructs each weight as (q - 16) * d.
  79. #define QK5_0 32
  80. #define QR5_0 2
  81. typedef struct {
  82. half d; // delta: per-block scale
  83. uint8_t qh[4]; // 5-th bit of quants (32-bit mask stored as bytes; at offset 2, so not guaranteed 4-byte aligned)
  84. uint8_t qs[QK5_0 / 2]; // nibbles / quants: low 4 bits, two quants per byte
  85. } block_q5_0;
  86. static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
  // q5_1: like q5_0 (4-bit nibbles in qs plus a 5th-bit mask in qh) but with a per-block
  // minimum; dequantize_q5_1 reconstructs each weight as q * d + m.
  87. #define QK5_1 32
  88. #define QR5_1 2
  89. typedef struct {
  90. half d; // delta: per-block scale
  91. half m; // min: per-block offset added after scaling
  92. uint8_t qh[4]; // 5-th bit of quants (32-bit mask stored as bytes)
  93. uint8_t qs[QK5_1 / 2]; // nibbles / quants: low 4 bits, two quants per byte
  94. } block_q5_1;
  95. static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
  // q8_0: 32 weights per block as signed 8-bit quants (one per byte) plus one fp16
  // scale; dequantize_q8_0 reconstructs each weight as q * d.
  96. #define QK8_0 32
  97. #define QR8_0 1
  98. typedef struct {
  99. half d; // delta: per-block scale
  100. int8_t qs[QK8_0]; // quants
  101. } block_q8_0;
  102. static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
  103. //================================= k-quants
  // k-quants operate on super-blocks of QK_K weights: 256 by default, 64 when built with
  // GGML_QKK_64. K_SCALE_SIZE is the byte size of the packed 6-bit scales array used by
  // block_q3_K / block_q5_K in the 256 layout; the 64-weight variants visible here
  // declare their scales arrays with explicit sizes instead.
  104. #ifdef GGML_QKK_64
  105. #define QK_K 64
  106. #define K_SCALE_SIZE 4
  107. #else
  108. #define QK_K 256
  109. #define K_SCALE_SIZE 12
  110. #endif
  // q2_K: super-block of QK_K weights with 2-bit quants, four per qs byte. Every 16
  // weights share one scales[] byte; dequantize_block_q2_K computes
  // d * (low nibble) * q - dmin * (high nibble).
  111. typedef struct {
  112. uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits: low nibble = scale, high nibble = min
  113. uint8_t qs[QK_K/4]; // quants: 2 bits each, four per byte
  114. half d; // super-block scale for quantized scales
  115. half dmin; // super-block scale for quantized mins
  116. } block_q2_K;
  117. static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
  // q3_K: super-block of QK_K weights with 3-bit quants. The low 2 bits of each quant are
  // packed four per byte in qs; the third bit lives in hmask. dequantize_block_q3_K
  // subtracts 4 from the 2-bit value whenever the corresponding hmask bit is clear.
  typedef struct {
      uint8_t hmask[QK_K/8];          // quants - high bit
      uint8_t qs[QK_K/4];             // quants - low 2 bits
  #ifdef GGML_QKK_64
      uint8_t scales[2];              // four 4-bit scales, two per byte (stored with an offset of 8)
  #else
      uint8_t scales[K_SCALE_SIZE];   // scales, quantized with 6 bits (stored with an offset of 32)
  #endif
      half d;                         // super-block scale
  } block_q3_K;

  // Layout check. It must be split per variant: the 64-weight layout stores 2 scale
  // bytes rather than K_SCALE_SIZE (4), so one K_SCALE_SIZE-based formula cannot
  // cover both (which is why a single assert had been commented out).
  #ifdef GGML_QKK_64
  static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
  #else
  static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
  #endif
  // q4_K: super-block of QK_K weights with 4-bit quants, two per qs byte.
  129. #ifdef GGML_QKK_64
  // 64-weight variant: each scales[] byte holds a (scale, min) nibble pair for one
  // 32-weight half. dequantize_block_q4_K multiplies the scale nibble by d[0] and
  // subtracts the min nibble times d[1].
  130. typedef struct {
  131. half d[2]; // super-block scales/mins: d[0] for scales, d[1] for mins
  132. uint8_t scales[2]; // 4-bit block scales/mins: low nibble = scale, high nibble = min
  133. uint8_t qs[QK_K/2]; // 4-bit quants
  134. } block_q4_K;
  135. static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
  136. #else
  // 256-weight variant: 8 sub-blocks of 32 weights, each with a 6-bit scale and a 6-bit
  // min packed into the 12-byte scales[] array and decoded by get_scale_min_k4.
  137. typedef struct {
  138. half d; // super-block scale for quantized scales
  139. half dmin; // super-block scale for quantized mins
  140. uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
  141. uint8_t qs[QK_K/2]; // 4-bit quants
  142. } block_q4_K;
  143. static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
  144. #endif
  // q5_K: super-block of QK_K weights with 5-bit quants (low 4 bits in qs, 5th bit in qh).
  145. #ifdef GGML_QKK_64
  // 64-weight variant: signed 8-bit scale per 16 weights. Bit b of qh[j] is the 5th bit
  // of weight 8*b + j; dequantize_block_q5_K subtracts 16 when that bit is clear.
  146. typedef struct {
  147. half d; // super-block scale
  148. int8_t scales[QK_K/16]; // block scales (signed, one per 16 weights)
  149. uint8_t qh[QK_K/8]; // quants, high bit
  150. uint8_t qs[QK_K/2]; // quants, low 4 bits
  151. } block_q5_K;
  152. static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
  153. #else
  // 256-weight variant: same scale/min packing as q4_K (decoded by get_scale_min_k4).
  // Bit k of qh[j] is the 5th bit (adds 16) of weight j in 32-weight sub-block k.
  154. typedef struct {
  155. half d; // super-block scale for quantized scales
  156. half dmin; // super-block scale for quantized mins
  157. uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
  158. uint8_t qh[QK_K/8]; // quants, high bit
  159. uint8_t qs[QK_K/2]; // quants, low 4 bits
  160. } block_q5_K;
  161. static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
  162. #endif
  // q6_K: super-block of QK_K weights with 6-bit quants split into a low nibble (ql) and
  // two high bits (qh). dequantize_block_q6_K reconstructs each weight as
  // d * scale * (q - 32), with one signed 8-bit scale per 16 weights.
  163. typedef struct {
  164. uint8_t ql[QK_K/2]; // quants, lower 4 bits (two quants per byte)
  165. uint8_t qh[QK_K/4]; // quants, upper 2 bits (four quants per byte)
  166. int8_t scales[QK_K/16]; // scales (signed, one per 16 weights)
  167. half d; // delta: super-block scale
  168. } block_q6_K;
  169. static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
  // Lanes per warp; the warp-shuffle reductions in this file (e.g. rms_norm_f32) rely on it.
  170. #define WARP_SIZE 32
  // Threads per block for the corresponding kernels' launchers (presumably; the
  // launchers are not visible in this section — TODO confirm).
  171. #define CUDA_ADD_BLOCK_SIZE 256
  172. #define CUDA_MUL_BLOCK_SIZE 256
  173. #define CUDA_SILU_BLOCK_SIZE 256
  174. #define CUDA_CPY_BLOCK_SIZE 32
  175. #define CUDA_SCALE_BLOCK_SIZE 256
  176. #define CUDA_ROPE_BLOCK_SIZE 256
  177. #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
  178. #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
  179. // dmmv = dequantize_mul_mat_vec
  // GGML_CUDA_DMMV_X / GGML_CUDA_DMMV_Y may be set at build time (-D...); the defaults
  // below apply only when they are not already defined.
  180. #ifndef GGML_CUDA_DMMV_X
  181. #define GGML_CUDA_DMMV_X 32
  182. #endif
  183. #ifndef GGML_CUDA_DMMV_Y
  184. #define GGML_CUDA_DMMV_Y 1
  185. #endif
  // K_QUANTS_PER_ITERATION (1 or 2): the k-quant dmmv kernels interleave a row's
  // super-blocks across this many thread groups (i = ix; i += K_QUANTS_PER_ITERATION).
  // Defaults to 2; a user-supplied value is validated by the static_assert.
  186. #ifndef K_QUANTS_PER_ITERATION
  187. #define K_QUANTS_PER_ITERATION 2
  188. #else
  189. static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
  190. #endif
  // Per-tensor GPU bookkeeping with one slot per device (GGML_CUDA_MAX_DEVICES).
  // NOTE(review): presumably attached to a ggml_tensor's extra pointer — confirm at the use sites.
  191. struct ggml_tensor_extra_gpu {
  192. void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
  193. cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
  194. };
  // Element-wise dst[i] = x[i] + y[i] for i in [0, k).
  // 1-D launch, one thread per element; threads whose index is >= k do nothing.
  static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
      const int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < k) {
          dst[idx] = x[idx] + y[idx];
      }
  }
  // Element-wise dst[i] = x[i] + y[i] for i in [0, k), where y is fp32: each y value is
  // rounded to half first and the sum is taken in half precision (__hadd).
  // 1-D launch, one thread per element; threads whose index is >= k do nothing.
  static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
      const int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < k) {
          const half rhs = __float2half(y[idx]);
          dst[idx] = __hadd(x[idx], rhs);
      }
  }
  // Element-wise dst[i] = x[i] * y[i % ky] for i in [0, kx): the ky-long vector y is
  // repeated cyclically across x.
  // 1-D launch, one thread per element; threads whose index is >= kx do nothing.
  static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
      const int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < kx) {
          const int iy = idx % ky;
          dst[idx] = x[idx] * y[iy];
      }
  }
  // Element-wise SiLU: dst[i] = x[i] / (1 + exp(-x[i])) for i in [0, k).
  // 1-D launch, one thread per element; threads whose index is >= k do nothing.
  static __global__ void silu_f32(const float * x, float * dst, const int k) {
      const int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < k) {
          const float v = x[idx];
          dst[idx] = v / (1.0f + expf(-v));
      }
  }
  // RMS-normalise each row: dst[row, c] = x[row, c] / sqrt(mean_c(x[row, c]^2) + eps).
  //
  // Launch layout (as required by the indexing below):
  //   - one warp per row: blockDim.x must equal WARP_SIZE, because the reduction
  //     shuffles across all 32 lanes with a full mask;
  //   - blockDim.y rows per block, row = blockIdx.x * blockDim.y + threadIdx.y.
  // Rows are ncols contiguous floats. There is no row bound check (the row count is not
  // passed in), so the grid must cover exactly the number of rows.
  // Any ncols >= 1 is handled: the strided column loops are bounded by ncols, so a
  // ragged tail (ncols % WARP_SIZE != 0) no longer reads or writes past the row.
  static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
      const int row = blockIdx.x*blockDim.y + threadIdx.y;
      const int tid = threadIdx.x;
      const float eps = 1e-6f;

      const float * x_row   = x   + row*ncols;
      float       * dst_row = dst + row*ncols;

      // Per-lane partial sum of squares; lane tid visits columns tid, tid + 32, ...
      float tmp = 0.0f;
      for (int col = tid; col < ncols; col += WARP_SIZE) {
          const float xi = x_row[col];
          tmp += xi * xi;
      }

      // Butterfly (xor) reduction within the warp: afterwards every lane holds the
      // full row sum. No shared memory is involved, so no block barrier is needed.
  #pragma unroll
      for (int mask = 16; mask > 0; mask >>= 1) {
          tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
      }

      const float mean  = tmp / ncols;
      const float scale = 1.0f / sqrtf(mean + eps);

      for (int col = tid; col < ncols; col += WARP_SIZE) {
          dst_row[col] = scale * x_row[col];
      }
  }
  // Dequantize the two 4-bit quants packed in byte iqs of q4_0 block ib:
  // v.x <- low nibble, v.y <- high nibble, each mapped to (q - 8) * d.
  static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
      const block_q4_0 & blk = static_cast<const block_q4_0 *>(vx)[ib];
      const dfloat d = blk.d;
      const int packed = blk.qs[iqs];

      v.x = packed & 0xF;
      v.y = packed >> 4;

  #ifdef GGML_CUDA_DMMV_F16
      // half2 path: re-center both lanes around zero, then scale.
      v = __hsub2(v, {8.0f, 8.0f});
      v = __hmul2(v, {d, d});
  #else
      v.x = (v.x - 8.0f) * d;
      v.y = (v.y - 8.0f) * d;
  #endif // GGML_CUDA_DMMV_F16
  }
  // Dequantize the two 4-bit quants packed in byte iqs of q4_1 block ib:
  // v.x <- low nibble, v.y <- high nibble, each mapped to q * d + m.
  static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
      const block_q4_1 & blk = static_cast<const block_q4_1 *>(vx)[ib];
      const dfloat d = blk.d;
      const dfloat m = blk.m;
      const int packed = blk.qs[iqs];

      v.x = packed & 0xF;
      v.y = packed >> 4;

  #ifdef GGML_CUDA_DMMV_F16
      // half2 path: scale both lanes, then shift both by the block minimum.
      v = __hmul2(v, {d, d});
      v = __hadd2(v, {m, m});
  #else
      v.x = (v.x * d) + m;
      v.y = (v.y * d) + m;
  #endif // GGML_CUDA_DMMV_F16
  }
  // Dequantize the two 5-bit quants stored in byte iqs of q5_0 block ib:
  //   v.x <- low nibble  with qh bit iqs        as its 5th bit,
  //   v.y <- high nibble with qh bit (iqs + 16) as its 5th bit,
  // each mapped to (q - 16) * d.
  static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
      const block_q5_0 & blk = static_cast<const block_q5_0 *>(vx)[ib];
      const dfloat d = blk.d;

      // qh follows a 2-byte half, so read the 32-bit mask via memcpy instead of
      // assuming 4-byte alignment.
      uint32_t high_bits;
      memcpy(&high_bits, blk.qh, sizeof(high_bits));
      const int bit4_lo = ((high_bits >> iqs) << 4) & 0x10;   // bit iqs        -> bit 4
      const int bit4_hi = (high_bits >> (iqs + 12)) & 0x10;   // bit (iqs + 16) -> bit 4

      const uint8_t packed = blk.qs[iqs];
      v.x = (packed & 0xf) | bit4_lo;
      v.y = (packed >> 4) | bit4_hi;

  #ifdef GGML_CUDA_DMMV_F16
      v = __hsub2(v, {16.0f, 16.0f});
      v = __hmul2(v, {d, d});
  #else
      v.x = (v.x - 16.0f) * d;
      v.y = (v.y - 16.0f) * d;
  #endif // GGML_CUDA_DMMV_F16
  }
  // Dequantize the two 5-bit quants stored in byte iqs of q5_1 block ib:
  //   v.x <- low nibble  with qh bit iqs        as its 5th bit,
  //   v.y <- high nibble with qh bit (iqs + 16) as its 5th bit,
  // each mapped to q * d + m.
  static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
      const block_q5_1 & blk = static_cast<const block_q5_1 *>(vx)[ib];
      const dfloat d = blk.d;
      const dfloat m = blk.m;

      // Read the 32-bit high-bit mask through memcpy rather than a uint32_t cast,
      // so no 4-byte alignment of blk.qh is assumed.
      uint32_t high_bits;
      memcpy(&high_bits, blk.qh, sizeof(high_bits));
      const int bit4_lo = ((high_bits >> iqs) << 4) & 0x10;   // bit iqs        -> bit 4
      const int bit4_hi = (high_bits >> (iqs + 12)) & 0x10;   // bit (iqs + 16) -> bit 4

      const uint8_t packed = blk.qs[iqs];
      v.x = (packed & 0xf) | bit4_lo;
      v.y = (packed >> 4) | bit4_hi;

  #ifdef GGML_CUDA_DMMV_F16
      v = __hmul2(v, {d, d});
      v = __hadd2(v, {m, m});
  #else
      v.x = (v.x * d) + m;
      v.y = (v.y * d) + m;
  #endif // GGML_CUDA_DMMV_F16
  }
  // Dequantize the consecutive int8 quants qs[iqs] and qs[iqs + 1] of q8_0 block ib
  // into v.x and v.y, each scaled by the block's d.
  static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
      const block_q8_0 & blk = static_cast<const block_q8_0 *>(vx)[ib];
      const dfloat d = blk.d;
      const int8_t * q = blk.qs + iqs;

      v.x = q[0];
      v.y = q[1];

  #ifdef GGML_CUDA_DMMV_F16
      v = __hmul2(v, {d, d});
  #else
      v.x *= d;
      v.y *= d;
  #endif // GGML_CUDA_DMMV_F16
  }
  322. //================================== k-quants
  // dequantize_block_q2_K: expand q2_K super-block i = blockIdx.x into the QK_K floats
  // yy[i*QK_K .. (i+1)*QK_K).
  // Each output = dall * (low nibble of its scales[] byte) * q - dmin * (high nibble),
  // where q is a 2-bit quant and one scales[] byte covers 16 consecutive outputs.
  323. static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
  324. const int i = blockIdx.x;
  325. const block_q2_K * x = (const block_q2_K *) vx;
  326. const int tid = threadIdx.x;
  327. #if QK_K == 256
  // QK_K == 256: the indexing covers the super-block with 64 threads; each thread reads
  // one qs byte and writes its four 2-bit fields to y[l], y[l+32], y[l+64], y[l+96].
  328. const int n = tid/32; // 0 or 1 (with 64 threads): which 128-output half / 32-byte half of qs
  329. const int l = tid - 32*n; // 0...31: position within that half
  330. const int is = 8*n + l/16; // scales[] index for y[l]; the later outputs use is+2, is+4, is+6
  331. const uint8_t q = x[i].qs[32*n + l];
  332. float * y = yy + i*QK_K + 128*n;
  333. float dall = x[i].d;
  334. float dmin = x[i].dmin;
  335. y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
  336. y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
  337. y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
  338. y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
  339. #else
  // QK_K == 64: 32 threads; each writes two outputs, y[0] and y[32], relative to 16*is + il.
  340. const int is = tid/16; // 0 or 1
  341. const int il = tid%16; // 0...15
  342. const uint8_t q = x[i].qs[il] >> (2*is); // align this thread's 2-bit fields to bits 0 and 4
  343. float * y = yy + i*QK_K + 16*is + il;
  344. float dall = x[i].d;
  345. float dmin = x[i].dmin;
  346. y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
  347. y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
  348. #endif
  349. }
  // dequantize_block_q3_K: expand q3_K super-block i = blockIdx.x into the QK_K floats
  // yy[i*QK_K .. (i+1)*QK_K).
  // Each 3-bit quant is a 2-bit field from qs minus 4 whenever its hmask bit is clear,
  // multiplied by d and by a per-sub-block scale (6-bit minus 32 for QK_K == 256,
  // 4-bit minus 8 for QK_K == 64).
  350. static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
  351. const int i = blockIdx.x;
  352. const block_q3_K * x = (const block_q3_K *) vx;
  353. #if QK_K == 256
  // QK_K == 256: the indexing covers the super-block with 64 threads, 4 outputs each.
  // n (0..1) selects the 128-output half, j (0..3) the 2-bit field (shift = 2*j) of each
  // qs byte, and is0 plus threadIdx.x%4 select 4 consecutive l in 0..31 from l0.
  354. const int r = threadIdx.x/4;
  355. const int tid = r/2;
  356. const int is0 = r%2;
  357. const int l0 = 16*is0 + 4*(threadIdx.x%4);
  358. const int n = tid / 4;
  359. const int j = tid - 4*n;
  360. uint8_t m = 1 << (4*n + j); // hmask bit for this (n, j) plane
  361. int is = 8*n + 2*j + is0; // sub-block (16 outputs) index, 0...15
  362. int shift = 2*j;
  // Rebuild the 6-bit scale of sub-block is: the low 4 bits come from a nibble of
  // scales[0..7], the top 2 bits from a 2-bit field of scales[8..11].
  363. int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
  364. is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
  365. is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
  366. (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
  367. float d_all = x[i].d;
  368. float dl = d_all * (us - 32); // scales are stored with an offset of 32
  369. float * y = yy + i*QK_K + 128*n + 32*j;
  370. const uint8_t * q = x[i].qs + 32*n;
  371. const uint8_t * hm = x[i].hmask;
  372. for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
  373. #else
  // QK_K == 64: 32 threads; each writes y[0] and y[32] relative to 16*is + il.
  // Scales are the four nibbles of scales[0..1], stored with an offset of 8.
  374. const int tid = threadIdx.x;
  375. const int is = tid/16; // 0 or 1
  376. const int il = tid%16; // 0...15
  377. const int im = il/8; // 0...1
  378. const int in = il%8; // 0...7
  379. float * y = yy + i*QK_K + 16*is + il;
  380. const uint8_t q = x[i].qs[il] >> (2*is);
  381. const uint8_t h = x[i].hmask[in] >> (2*is + im);
  382. const float d = (float)x[i].d;
  383. if (is == 0) {
  384. y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
  385. y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
  386. } else {
  387. y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
  388. y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
  389. }
  390. #endif
  391. }
  #if QK_K == 256
  // Unpack the j-th 6-bit (scale, min) pair from the 12-byte packed k-quant scales array q
  // (callers in this file pass j in 0...7):
  //   j < 4 : scale = low 6 bits of q[j], min = low 6 bits of q[j + 4];
  //   j >= 4: q[j + 4] supplies the low 4 bits of both (low nibble -> scale, high nibble
  //           -> min); the top 2 bits of q[j - 4] and q[j] supply their upper 2 bits.
  static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
      if (j >= 4) {
          const uint8_t packed = q[j + 4];
          d = (packed & 0xF) | ((q[j - 4] >> 6) << 4);
          m = (packed >> 4)  | ((q[j]     >> 6) << 4);
          return;
      }
      d = q[j]     & 63;
      m = q[j + 4] & 63;
  }
  #endif
  // dequantize_block_q4_K: expand q4_K super-block i = blockIdx.x into the QK_K floats
  // yy[i*QK_K .. (i+1)*QK_K). Each output = (super scale * sub-block scale) * q
  // - (super min * sub-block min), with q a 4-bit quant.
  402. static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
  403. const block_q4_K * x = (const block_q4_K *) vx;
  404. const int i = blockIdx.x;
  405. #if QK_K == 256
  406. // assume 32 threads: each handles 4 qs bytes -> 4 low-nibble outputs + 4 high-nibble outputs
  407. const int tid = threadIdx.x;
  408. const int il = tid/8; // 0...3: 64-output group (one pair of 32-output sub-blocks)
  409. const int ir = tid%8; // 0...7: 4-byte chunk within the group's 32 qs bytes
  410. const int is = 2*il; // sub-block of the low nibbles; is + 1 for the high nibbles
  411. const int n = 4;
  412. float * y = yy + i*QK_K + 64*il + n*ir;
  413. const float dall = x[i].d;
  414. const float dmin = x[i].dmin;
  415. const uint8_t * q = x[i].qs + 32*il + n*ir;
  416. uint8_t sc, m;
  417. get_scale_min_k4(is + 0, x[i].scales, sc, m);
  418. const float d1 = dall * sc; const float m1 = dmin * m;
  419. get_scale_min_k4(is + 1, x[i].scales, sc, m);
  420. const float d2 = dall * sc; const float m2 = dmin * m;
  421. for (int l = 0; l < n; ++l) {
  422. y[l + 0] = d1 * (q[l] & 0xF) - m1;
  423. y[l +32] = d2 * (q[l] >> 4) - m2;
  424. }
  425. #else
  // QK_K == 64: 32 threads; thread tid expands qs[tid] into y[tid] (low nibble, scales[0])
  // and y[tid + 32] (high nibble, scales[1]); d[0] scales, d[1] mins.
  426. const int tid = threadIdx.x;
  427. const uint8_t * q = x[i].qs;
  428. float * y = yy + i*QK_K;
  429. const float d = (float)x[i].d[0];
  430. const float m = (float)x[i].d[1];
  431. y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
  432. y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
  433. #endif
  434. }
  // dequantize_block_q5_K: expand q5_K super-block i = blockIdx.x into the QK_K floats
  // yy[i*QK_K .. (i+1)*QK_K). Each 5-bit quant = 4-bit nibble from qs plus a high bit
  // from qh.
  435. static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
  436. const block_q5_K * x = (const block_q5_K *) vx;
  437. const int i = blockIdx.x;
  438. #if QK_K == 256
  439. // assume 64 threads: each writes 4 outputs (y[0], y[1] in sub-block is; y[32], y[33] in sub-block is + 1)
  440. const int tid = threadIdx.x;
  441. const int il = tid/16; // il is in 0...3
  442. const int ir = tid%16; // ir is in 0...15
  443. const int is = 2*il; // is is in 0...6
  444. float * y = yy + i*QK_K + 64*il + 2*ir;
  445. const float dall = x[i].d;
  446. const float dmin = x[i].dmin;
  447. const uint8_t * ql = x[i].qs + 32*il + 2*ir;
  448. const uint8_t * qh = x[i].qh + 2*ir;
  449. uint8_t sc, m;
  450. get_scale_min_k4(is + 0, x[i].scales, sc, m);
  451. const float d1 = dall * sc; const float m1 = dmin * m;
  452. get_scale_min_k4(is + 1, x[i].scales, sc, m);
  453. const float d2 = dall * sc; const float m2 = dmin * m;
  454. uint8_t hm = 1 << (2*il); // qh bit k holds the high bit for sub-block k; a set bit adds 16
  455. y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
  456. y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
  457. hm <<= 1;
  458. y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
  459. y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
  460. #else
  // QK_K == 64: 32 threads; thread tid writes y[tid] and y[tid + 32]. Bit b of qh[j] is the
  // high bit of weight 8*b + j; 16 is subtracted when it is clear. Scales are signed int8.
  461. const int tid = threadIdx.x;
  462. const uint8_t q = x[i].qs[tid];
  463. const int im = tid/8; // 0...3
  464. const int in = tid%8; // 0...7
  465. const int is = tid/16; // 0 or 1
  466. const uint8_t h = x[i].qh[in] >> im;
  467. const float d = x[i].d;
  468. float * y = yy + i*QK_K + tid;
  469. y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
  470. y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
  471. #endif
  472. }
  // dequantize_block_q6_K: expand q6_K super-block i = blockIdx.x into the QK_K floats
  // yy[i*QK_K .. (i+1)*QK_K). Each 6-bit quant = low nibble from ql | two bits from qh,
  // mapped to d * scale * (q - 32) with one signed scale per 16 outputs.
  473. static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
  474. const block_q6_K * x = (const block_q6_K *) vx;
  475. const int i = blockIdx.x;
  476. #if QK_K == 256
  477. // assume 64 threads: each writes y[0], y[32], y[64], y[96] relative to 128*ip + il
  478. const int tid = threadIdx.x;
  479. const int ip = tid/32; // ip is 0 or 1
  480. const int il = tid - 32*ip; // 0...31
  481. const int is = 8*ip + il/16;
  482. float * y = yy + i*QK_K + 128*ip + il;
  483. const float d = x[i].d;
  484. const uint8_t * ql = x[i].ql + 64*ip + il;
  485. const uint8_t qh = x[i].qh[32*ip + il]; // four 2-bit high fields, one per output
  486. const int8_t * sc = x[i].scales + is;
  487. y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
  488. y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
  489. y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
  490. y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
  491. #else
  492. // assume 32 threads: each writes y[0] and y[32] relative to 16*ip + il
  493. const int tid = threadIdx.x;
  494. const int ip = tid/16; // 0 or 1
  495. const int il = tid - 16*ip; // 0...15
  496. float * y = yy + i*QK_K + 16*ip + il;
  497. const float d = x[i].d;
  498. const uint8_t ql = x[i].ql[16*ip + il];
  499. const uint8_t qh = x[i].qh[il] >> (2*ip);
  500. const int8_t * sc = x[i].scales;
  501. y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
  502. y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32);
  503. #endif
  504. }
  505. static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
  506. static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
  507. const int row = blockIdx.y*blockDim.y + threadIdx.y;
  508. if (row > nrows) return;
  509. const int num_blocks_per_row = ncols / QK_K;
  510. const int ib0 = row*num_blocks_per_row;
  511. const block_q2_K * x = (const block_q2_K *)vx + ib0;
  512. float tmp = 0; // partial sum for thread in warp
  513. #if QK_K == 256
  514. const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
  515. const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
  516. const int step = 16/K_QUANTS_PER_ITERATION;
  517. const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
  518. const int in = tid - step*im; // 0...15 or 0...7
  519. const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
  520. const int q_offset = 32*im + l0;
  521. const int s_offset = 8*im;
  522. const int y_offset = 128*im + l0;
  523. uint32_t aux[4];
  524. const uint8_t * d = (const uint8_t *)aux;
  525. const uint8_t * m = (const uint8_t *)(aux + 2);
  526. for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
  527. const float * y = yy + i * QK_K + y_offset;
  528. const uint8_t * q = x[i].qs + q_offset;
  529. const float dall = x[i].d;
  530. const float dmin = x[i].dmin;
  531. const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
  532. aux[0] = a[0] & 0x0f0f0f0f;
  533. aux[1] = a[1] & 0x0f0f0f0f;
  534. aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
  535. aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
  536. float sum1 = 0, sum2 = 0;
  537. for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
  538. sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
  539. + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
  540. + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
  541. + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
  542. + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
  543. + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
  544. + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
  545. +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
  546. sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
  547. + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
  548. }
  549. tmp += dall * sum1 - dmin * sum2;
  550. }
  551. #else
  552. const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
  553. const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
  554. const int offset = tid * K_QUANTS_PER_ITERATION;
  555. uint32_t uaux[2];
  556. const uint8_t * d = (const uint8_t *)uaux;
  557. for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
  558. const float * y = yy + i * QK_K + offset;
  559. const uint8_t * q = x[i].qs + offset;
  560. const uint32_t * s = (const uint32_t *)x[i].scales;
  561. uaux[0] = s[0] & 0x0f0f0f0f;
  562. uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
  563. const half2 * dh = (const half2 *)&x[i].d;
  564. const float2 dall = __half22float2(dh[0]);
  565. float sum1 = 0, sum2 = 0;
  566. for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
  567. const uint8_t ql = q[l];
  568. sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
  569. + y[l+16] * d[1] * ((ql >> 2) & 3)
  570. + y[l+32] * d[2] * ((ql >> 4) & 3)
  571. + y[l+48] * d[3] * ((ql >> 6) & 3);
  572. sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
  573. }
  574. tmp += dall.x * sum1 - dall.y * sum2;
  575. }
  576. #endif
  577. // sum up partial sums and write back result
  578. __syncthreads();
  579. #pragma unroll
  580. for (int mask = 16; mask > 0; mask >>= 1) {
  581. tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
  582. }
  583. if (threadIdx.x == 0) {
  584. dst[row] = tmp;
  585. }
  586. }
  587. static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
  588. const int row = blockIdx.y*blockDim.y + threadIdx.y;
  589. if (row > nrows) return;
  590. const int num_blocks_per_row = ncols / QK_K;
  591. const int ib0 = row*num_blocks_per_row;
  592. const block_q3_K * x = (const block_q3_K *)vx + ib0;
  593. float tmp = 0; // partial sum for thread in warp
  594. #if QK_K == 256
  595. const uint16_t kmask1 = 0x0303;
  596. const uint16_t kmask2 = 0x0f0f;
  597. const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
  598. const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
  599. const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
  600. const int step = 16/K_QUANTS_PER_ITERATION;
  601. const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
  602. const int in = tid - step*im; // 0....15 or 0...7
  603. const uint8_t m = 1 << (4*im);
  604. const int l0 = n*in; // 0...15 or 0...14 in steps of 2
  605. const int q_offset = 32*im + l0;
  606. const int y_offset = 128*im + l0;
  607. uint16_t utmp[4];
  608. const int8_t * s = (const int8_t *)utmp;
  609. const uint16_t s_shift = 4*im;
  610. for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
  611. const float * y = yy + i * QK_K + y_offset;
  612. const uint8_t * q = x[i].qs + q_offset;
  613. const uint8_t * h = x[i].hmask + l0;
  614. const uint16_t * a = (const uint16_t *)x[i].scales;
  615. utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
  616. utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
  617. utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
  618. utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
  619. const float d = x[i].d;
  620. float sum = 0;
  621. for (int l = 0; l < n; ++l) {
  622. sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
  623. + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
  624. + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
  625. + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
  626. sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
  627. + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
  628. + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
  629. + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
  630. }
  631. tmp += d * sum;
  632. }
  633. #else
  634. const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
  635. const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
  636. const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14
  637. const int in = offset/8; // 0 or 1
  638. const int im = offset%8; // 0...7
  639. for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
  640. const float * y = yy + i * QK_K + offset;
  641. const uint8_t * q = x[i].qs + offset;
  642. const uint8_t * s = x[i].scales;
  643. const float dall = (float)x[i].d;
  644. float sum = 0;
  645. for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
  646. const uint8_t hl = x[i].hmask[im+l] >> in;
  647. const uint8_t ql = q[l];
  648. sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
  649. + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
  650. + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
  651. + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
  652. }
  653. tmp += sum;
  654. }
  655. #endif
  656. // sum up partial sums and write back result
  657. __syncthreads();
  658. #pragma unroll
  659. for (int mask = 16; mask > 0; mask >>= 1) {
  660. tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
  661. }
  662. if (threadIdx.x == 0) {
  663. dst[row] = tmp;
  664. }
  665. }
  666. static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
  667. const int row = blockIdx.y*blockDim.y + threadIdx.y;
  668. if (row > nrows) return;
  669. const int num_blocks_per_row = ncols / QK_K;
  670. const int ib0 = row*num_blocks_per_row;
  671. const block_q4_K * x = (const block_q4_K *)vx + ib0;
  672. #if QK_K == 256
  673. const uint16_t kmask1 = 0x3f3f;
  674. const uint16_t kmask2 = 0x0f0f;
  675. const uint16_t kmask3 = 0xc0c0;
  676. const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
  677. const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
  678. const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
  679. const int il = tid/step; // 0...3
  680. const int ir = tid - step*il; // 0...7 or 0...3
  681. const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
  682. const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
  683. const int in = il%2;
  684. const int l0 = n*(2*ir + in);
  685. const int q_offset = 32*im + l0;
  686. const int y_offset = 64*im + l0;
  687. uint16_t aux[4];
  688. const uint8_t * sc = (const uint8_t *)aux;
  689. float tmp = 0; // partial sum for thread in warp
  690. for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
  691. const uint8_t * q1 = x[i].qs + q_offset;
  692. const uint8_t * q2 = q1 + 64;
  693. const float * y1 = yy + i*QK_K + y_offset;
  694. const float * y2 = y1 + 128;
  695. const float dall = x[i].d;
  696. const float dmin = x[i].dmin;
  697. const uint16_t * a = (const uint16_t *)x[i].scales;
  698. aux[0] = a[im+0] & kmask1;
  699. aux[1] = a[im+2] & kmask1;
  700. aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
  701. aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
  702. float4 s = {0.f, 0.f, 0.f, 0.f};
  703. float smin = 0;
  704. for (int l = 0; l < n; ++l) {
  705. s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
  706. s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
  707. smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
  708. }
  709. tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
  710. }
  711. #else
  712. const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
  713. const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
  714. const int step = tid * K_QUANTS_PER_ITERATION;
  715. uint16_t aux16[2];
  716. const uint8_t * s = (const uint8_t *)aux16;
  717. float tmp = 0;
  718. for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
  719. const uint8_t * q = x[i].qs + step;
  720. const float * y = yy + i*QK_K + step;
  721. const uint16_t * a = (const uint16_t *)x[i].scales;
  722. aux16[0] = a[0] & 0x0f0f;
  723. aux16[1] = (a[0] >> 4) & 0x0f0f;
  724. const float d = (float)x[i].d[0];
  725. const float m = (float)x[i].d[1];
  726. float sum = 0.f;
  727. for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
  728. sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
  729. + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
  730. + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
  731. + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
  732. }
  733. tmp += sum;
  734. }
  735. #endif
  736. // sum up partial sums and write back result
  737. __syncthreads();
  738. #pragma unroll
  739. for (int mask = 16; mask > 0; mask >>= 1) {
  740. tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
  741. }
  742. if (tid == 0) {
  743. dst[row] = tmp;
  744. }
  745. }
  746. static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
  747. const int row = blockIdx.x;
  748. const int num_blocks_per_row = ncols / QK_K;
  749. const int ib0 = row*num_blocks_per_row;
  750. const block_q5_K * x = (const block_q5_K *)vx + ib0;
  751. float tmp = 0; // partial sum for thread in warp
  752. #if QK_K == 256
  753. const uint16_t kmask1 = 0x3f3f;
  754. const uint16_t kmask2 = 0x0f0f;
  755. const uint16_t kmask3 = 0xc0c0;
  756. const int tid = threadIdx.x/2; // 0...15
  757. const int ix = threadIdx.x%2;
  758. const int il = tid/4; // 0...3
  759. const int ir = tid - 4*il;// 0...3
  760. const int n = 2;
  761. const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
  762. const int in = il%2;
  763. const int l0 = n*(2*ir + in);
  764. const int q_offset = 32*im + l0;
  765. const int y_offset = 64*im + l0;
  766. const uint8_t hm1 = 1 << (2*im);
  767. const uint8_t hm2 = hm1 << 4;
  768. uint16_t aux[4];
  769. const uint8_t * sc = (const uint8_t *)aux;
  770. for (int i = ix; i < num_blocks_per_row; i += 2) {
  771. const uint8_t * ql1 = x[i].qs + q_offset;
  772. const uint8_t * ql2 = ql1 + 64;
  773. const uint8_t * qh = x[i].qh + l0;
  774. const float * y1 = yy + i*QK_K + y_offset;
  775. const float * y2 = y1 + 128;
  776. const float dall = x[i].d;
  777. const float dmin = x[i].dmin;
  778. const uint16_t * a = (const uint16_t *)x[i].scales;
  779. aux[0] = a[im+0] & kmask1;
  780. aux[1] = a[im+2] & kmask1;
  781. aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
  782. aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
  783. float4 sum = {0.f, 0.f, 0.f, 0.f};
  784. float smin = 0;
  785. for (int l = 0; l < n; ++l) {
  786. sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
  787. + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
  788. sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
  789. + y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
  790. sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
  791. + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
  792. sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
  793. + y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
  794. smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
  795. + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
  796. }
  797. tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
  798. }
  799. #else
  800. const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
  801. const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
  802. const int step = tid * K_QUANTS_PER_ITERATION;
  803. const int im = step/8;
  804. const int in = step%8;
  805. for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
  806. const uint8_t * q = x[i].qs + step;
  807. const int8_t * s = x[i].scales;
  808. const float * y = yy + i*QK_K + step;
  809. const float d = x[i].d;
  810. float sum = 0.f;
  811. for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
  812. const uint8_t h = x[i].qh[in+j] >> im;
  813. sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
  814. + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
  815. + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16))
  816. + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16));
  817. }
  818. tmp += sum;
  819. }
  820. #endif
  821. // sum up partial sums and write back result
  822. __syncthreads();
  823. #pragma unroll
  824. for (int mask = 16; mask > 0; mask >>= 1) {
  825. tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
  826. }
  827. if (threadIdx.x == 0) {
  828. dst[row] = tmp;
  829. }
  830. }
  831. static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
  832. static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
  833. const int row = blockIdx.y*blockDim.y + threadIdx.y;
  834. if (row > nrows) return;
  835. const int num_blocks_per_row = ncols / QK_K;
  836. const int ib0 = row*num_blocks_per_row;
  837. const block_q6_K * x = (const block_q6_K *)vx + ib0;
  838. #if QK_K == 256
  839. const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
  840. const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
  841. const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
  842. const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
  843. const int in = tid - step*im; // 0...15 or 0...7
  844. #if K_QUANTS_PER_ITERATION == 1
  845. const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
  846. const int is = 0;
  847. #else
  848. const int l0 = 4 * in; // 0, 4, 8, ..., 28
  849. const int is = in / 4;
  850. #endif
  851. const int ql_offset = 64*im + l0;
  852. const int qh_offset = 32*im + l0;
  853. const int s_offset = 8*im + is;
  854. const int y_offset = 128*im + l0;
  855. float tmp = 0; // partial sum for thread in warp
  856. for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
  857. const float * y = yy + i * QK_K + y_offset;
  858. const uint8_t * ql = x[i].ql + ql_offset;
  859. const uint8_t * qh = x[i].qh + qh_offset;
  860. const int8_t * s = x[i].scales + s_offset;
  861. const float d = x[i].d;
  862. #if K_QUANTS_PER_ITERATION == 1
  863. float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
  864. + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
  865. + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
  866. + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
  867. + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
  868. + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
  869. + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
  870. +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
  871. tmp += sum;
  872. #else
  873. float sum = 0;
  874. for (int l = 0; l < 4; ++l) {
  875. sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
  876. + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
  877. + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
  878. + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
  879. }
  880. tmp += sum;
  881. #endif
  882. }
  883. #else
  884. const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7
  885. const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3
  886. const int step = tid * K_QUANTS_PER_ITERATION;
  887. float tmp = 0; // partial sum for thread in warp
  888. for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
  889. const float * y = yy + i * QK_K + step;
  890. const uint8_t * ql = x[i].ql + step;
  891. const uint8_t * qh = x[i].qh + step;
  892. const int8_t * s = x[i].scales;
  893. const float d = x[i+0].d;
  894. float sum = 0;
  895. for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
  896. sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
  897. + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
  898. + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
  899. + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
  900. }
  901. tmp += sum;
  902. }
  903. #endif
  904. // sum up partial sums and write back result
  905. __syncthreads();
  906. #pragma unroll
  907. for (int mask = 16; mask > 0; mask >>= 1) {
  908. tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
  909. }
  910. if (tid == 0) {
  911. dst[row] = tmp;
  912. }
  913. }
  914. static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
  915. const half * x = (const half *) vx;
  916. // automatic half -> float type cast if dfloat == float
  917. v.x = x[ib + iqs + 0];
  918. v.y = x[ib + iqs + 1];
  919. }
  920. template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
  921. static __global__ void dequantize_block(const void * vx, float * y, const int k) {
  922. const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
  923. if (i >= k) {
  924. return;
  925. }
  926. const int ib = i/qk; // block index
  927. const int iqs = (i%qk)/qr; // quant index
  928. const int iybs = i - i%qk; // y block start index
  929. const int y_offset = qr == 1 ? 1 : qk/2;
  930. // dequantize
  931. dfloat2 v;
  932. dequantize_kernel(vx, ib, iqs, v);
  933. y[iybs + iqs + 0] = v.x;
  934. y[iybs + iqs + y_offset] = v.y;
  935. }
  936. template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
  937. static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
  938. // qk = quantized weights per x block
  939. // qr = number of quantized weights per data value in x block
  940. const int row = blockIdx.y*blockDim.y + threadIdx.y;
  941. if (row >= nrows) {
  942. return;
  943. }
  944. const int tid = threadIdx.x;
  945. const int iter_stride = 2*GGML_CUDA_DMMV_X;
  946. const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
  947. const int y_offset = qr == 1 ? 1 : qk/2;
  948. // partial sum for each thread
  949. #ifdef GGML_CUDA_DMMV_F16
  950. half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
  951. #else
  952. float tmp = 0.0f;
  953. #endif // GGML_CUDA_DMMV_F16
  954. for (int i = 0; i < ncols; i += iter_stride) {
  955. const int col = i + vals_per_iter*tid;
  956. const int ib = (row*ncols + col)/qk; // x block index
  957. const int iqs = (col%qk)/qr; // x quant index
  958. const int iybs = col - col%qk; // y block start index
  959. // processing >2 values per i iter is faster for fast GPUs
  960. #pragma unroll
  961. for (int j = 0; j < vals_per_iter; j += 2) {
  962. // process 2 vals per j iter
  963. // dequantize
  964. // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
  965. dfloat2 v;
  966. dequantize_kernel(vx, ib, iqs + j/qr, v);
  967. // matrix multiplication
  968. // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
  969. #ifdef GGML_CUDA_DMMV_F16
  970. tmp += __hmul2(v, {
  971. y[iybs + iqs + j/qr + 0],
  972. y[iybs + iqs + j/qr + y_offset]
  973. });
  974. #else
  975. tmp += v.x * y[iybs + iqs + j/qr + 0];
  976. tmp += v.y * y[iybs + iqs + j/qr + y_offset];
  977. #endif // GGML_CUDA_DMMV_F16
  978. }
  979. }
  980. // sum up partial sums and write back result
  981. __syncthreads();
  982. #pragma unroll
  983. for (int mask = 16; mask > 0; mask >>= 1) {
  984. tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
  985. }
  986. if (tid == 0) {
  987. #ifdef GGML_CUDA_DMMV_F16
  988. dst[row] = tmp.x + tmp.y;
  989. #else
  990. dst[row] = tmp;
  991. #endif // GGML_CUDA_DMMV_F16
  992. }
  993. }
  994. static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
  995. const half * x = (const half *) vx;
  996. const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
  997. const int channel = blockDim.z*blockIdx.z + threadIdx.z;
  998. const int nrows_y = ncols_x;
  999. const int nrows_dst = nrows_x;
  1000. const int row_dst = row_x;
  1001. float tmp = 0.0f;
  1002. for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
  1003. const int col_x = col_x0 + threadIdx.x;
  1004. if (col_x >= ncols_x) {
  1005. break;
  1006. }
  1007. // x is transposed and permuted
  1008. const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x;
  1009. const float xi = __half2float(x[ix]);
  1010. const int row_y = col_x;
  1011. // y is not transposed but permuted
  1012. const int iy = channel*nrows_y + row_y;
  1013. tmp += xi * y[iy];
  1014. }
  1015. // dst is not transposed and not permuted
  1016. const int idst = channel*nrows_dst + row_dst;
  1017. // sum up partial sums and write back result
  1018. __syncthreads();
  1019. #pragma unroll
  1020. for (int mask = 16; mask > 0; mask >>= 1) {
  1021. tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
  1022. }
  1023. if (threadIdx.x == 0) {
  1024. dst[idst] = tmp;
  1025. }
  1026. }
  1027. static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
  1028. const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
  1029. const int row_stride_x, const int channel_stride_x) {
  1030. const half * x = (const half *) vx;
  1031. const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
  1032. const int channel = blockDim.z*blockIdx.z + threadIdx.z;
  1033. const int nrows_y = ncols_x;
  1034. const int nrows_dst = nrows_x;
  1035. const int row_dst = row_x;
  1036. const int idst = channel*nrows_dst + row_dst;
  1037. float tmp = 0.0f;
  1038. for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
  1039. const int col_x = col_x0 + threadIdx.x;
  1040. if (col_x >= ncols_x) {
  1041. break;
  1042. }
  1043. const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
  1044. const float xi = __half2float(x[ix]);
  1045. const int row_y = col_x;
  1046. const int iy = channel*nrows_y + row_y;
  1047. tmp += xi * y[iy];
  1048. }
  1049. // sum up partial sums and write back result
  1050. __syncthreads();
  1051. #pragma unroll
  1052. for (int mask = 16; mask > 0; mask >>= 1) {
  1053. tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
  1054. }
  1055. if (threadIdx.x == 0) {
  1056. dst[idst] = tmp;
  1057. }
  1058. }
  1059. static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
  1060. const float * xi = (const float *) cxi;
  1061. float * dsti = (float *) cdsti;
  1062. *dsti = *xi;
  1063. }
  1064. static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
  1065. const float * xi = (const float *) cxi;
  1066. half * dsti = (half *) cdsti;
  1067. *dsti = __float2half(*xi);
  1068. }
  1069. template <cpy_kernel_t cpy_1>
  1070. static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
  1071. const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
  1072. const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
  1073. const int i = blockDim.x*blockIdx.x + threadIdx.x;
  1074. if (i >= ne) {
  1075. return;
  1076. }
  1077. // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
  1078. // then combine those indices with the corresponding byte offsets to get the total offsets
  1079. const int i02 = i / (ne00*ne01);
  1080. const int i01 = (i - i02*ne01*ne00) / ne00;
  1081. const int i00 = i - i02*ne01*ne00 - i01*ne00;
  1082. const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
  1083. const int i12 = i / (ne10*ne11);
  1084. const int i11 = (i - i12*ne10*ne11) / ne10;
  1085. const int i10 = i - i12*ne10*ne11 - i11*ne10;
  1086. const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
  1087. cpy_1(cx + x_offset, cdst + dst_offset);
  1088. }
  1089. // rope == RoPE == rotary positional embedding
  1090. static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
  1091. const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
  1092. if (col >= ncols) {
  1093. return;
  1094. }
  1095. const int row = blockDim.y*blockIdx.y + threadIdx.y;
  1096. const int i = row*ncols + col;
  1097. const float theta = p*powf(theta_scale, col/2);
  1098. const float sin_theta = sinf(theta);
  1099. const float cos_theta = cosf(theta);
  1100. const float x0 = x[i + 0];
  1101. const float x1 = x[i + 1];
  1102. dst[i + 0] = x0*cos_theta - x1*sin_theta;
  1103. dst[i + 1] = x0*sin_theta + x1*cos_theta;
  1104. }
  1105. static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
  1106. const int col = blockDim.x*blockIdx.x + threadIdx.x;
  1107. const int row = blockDim.y*blockIdx.y + threadIdx.y;
  1108. if (col >= ncols) {
  1109. return;
  1110. }
  1111. const int i = row*ncols + col;
  1112. // dst[i] = col > n_past + row ? -INFINITY : x[i];
  1113. dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
  1114. }
  1115. // the CUDA soft max implementation differs from the CPU implementation
  1116. // instead of doubles floats are used
  1117. // values are also not normalized to the maximum value by subtracting it in the exponential function
  1118. // theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
  1119. static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
  1120. const int row = blockDim.y*blockIdx.y + threadIdx.y;
  1121. const int block_size = blockDim.x;
  1122. const int tid = threadIdx.x;
  1123. float tmp = 0.0;
  1124. for (int block_start = 0; block_start < ncols; block_start += block_size) {
  1125. const int col = block_start + tid;
  1126. if (col >= ncols) {
  1127. break;
  1128. }
  1129. const int i = row*ncols + col;
  1130. const float val = expf(x[i]);
  1131. tmp += val;
  1132. dst[i] = val;
  1133. }
  1134. // sum up partial sums
  1135. __syncthreads();
  1136. #pragma unroll
  1137. for (int mask = 16; mask > 0; mask >>= 1) {
  1138. tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
  1139. }
  1140. for (int block_start = 0; block_start < ncols; block_start += block_size) {
  1141. const int col = block_start + tid;
  1142. if (col >= ncols) {
  1143. break;
  1144. }
  1145. const int i = row*ncols + col;
  1146. dst[i] /= tmp;
  1147. }
  1148. }
  1149. static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
  1150. const int i = blockDim.x*blockIdx.x + threadIdx.x;
  1151. if (i >= k) {
  1152. return;
  1153. }
  1154. dst[i] = scale * x[i];
  1155. }
  1156. static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
  1157. const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
  1158. add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
  1159. }
  1160. static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
  1161. const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
  1162. add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
  1163. }
  1164. static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
  1165. const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
  1166. mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
  1167. }
  1168. static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
  1169. const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
  1170. silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
  1171. }
  1172. static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  1173. GGML_ASSERT(ncols % WARP_SIZE == 0);
  1174. const dim3 block_dims(WARP_SIZE, 1, 1);
  1175. rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
  1176. }
  1177. static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  1178. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  1179. dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  1180. }
  1181. static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  1182. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  1183. dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  1184. }
  1185. static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  1186. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  1187. dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  1188. }
  1189. static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  1190. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  1191. dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  1192. }
  1193. static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  1194. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  1195. dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  1196. }
  1197. static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  1198. const int nb = k / QK_K;
  1199. #if QK_K == 256
  1200. dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
  1201. #else
  1202. dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
  1203. #endif
  1204. }
  1205. static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  1206. const int nb = k / QK_K;
  1207. #if QK_K == 256
  1208. dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
  1209. #else
  1210. dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
  1211. #endif
  1212. }
  1213. static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  1214. const int nb = k / QK_K;
  1215. dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
  1216. }
  1217. static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  1218. const int nb = k / QK_K;
  1219. #if QK_K == 256
  1220. dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
  1221. #else
  1222. dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
  1223. #endif
  1224. }
  1225. static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  1226. const int nb = k / QK_K;
  1227. #if QK_K == 256
  1228. dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
  1229. #else
  1230. dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
  1231. #endif
  1232. }
  1233. static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  1234. GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
  1235. const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
  1236. const dim3 block_nums(1, block_num_y, 1);
  1237. const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
  1238. dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
  1239. <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  1240. }
  1241. static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  1242. GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
  1243. const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
  1244. const dim3 block_nums(1, block_num_y, 1);
  1245. const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
  1246. dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
  1247. <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  1248. }
  1249. static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  1250. GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
  1251. const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
  1252. const dim3 block_nums(1, block_num_y, 1);
  1253. const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
  1254. dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
  1255. <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  1256. }
  1257. static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  1258. GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
  1259. const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
  1260. const dim3 block_nums(1, block_num_y, 1);
  1261. const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
  1262. dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
  1263. <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  1264. }
  1265. static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  1266. GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
  1267. const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
  1268. const dim3 block_nums(1, block_num_y, 1);
  1269. const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
  1270. dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
  1271. <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  1272. }
  1273. static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  1274. GGML_ASSERT(ncols % QK_K == 0);
  1275. const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
  1276. const int block_num_y = (nrows + ny - 1) / ny;
  1277. const dim3 block_nums(1, block_num_y, 1);
  1278. const dim3 block_dims(32, ny, 1);
  1279. dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  1280. }
  1281. static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  1282. GGML_ASSERT(ncols % QK_K == 0);
  1283. const int ny = 2 / K_QUANTS_PER_ITERATION;
  1284. const int block_num_y = (nrows + ny - 1) / ny;
  1285. const dim3 block_nums(1, block_num_y, 1);
  1286. const dim3 block_dims(32, ny, 1);
  1287. dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  1288. }
  1289. static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  1290. GGML_ASSERT(ncols % QK_K == 0);
  1291. const int ny = 2 / K_QUANTS_PER_ITERATION;
  1292. const int block_num_y = (nrows + ny - 1) / ny;
  1293. const dim3 block_nums(1, block_num_y, 1);
  1294. const dim3 block_dims(32, ny, 1);
  1295. dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  1296. }
  1297. static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  1298. GGML_ASSERT(ncols % QK_K == 0);
  1299. const dim3 block_dims(32, 1, 1);
  1300. dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
  1301. }
  1302. static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  1303. GGML_ASSERT(ncols % QK_K == 0);
  1304. const int ny = 2 / K_QUANTS_PER_ITERATION;
  1305. const int block_num_y = (nrows + ny - 1) / ny;
  1306. const dim3 block_nums(1, block_num_y, 1);
  1307. const dim3 block_dims(32, ny, 1);
  1308. dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  1309. }
  1310. static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  1311. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  1312. dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  1313. }
  1314. static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  1315. GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
  1316. const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
  1317. const dim3 block_nums(1, block_num_y, 1);
  1318. const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
  1319. dequantize_mul_mat_vec<1, 1, convert_f16>
  1320. <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  1321. }
  1322. static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
  1323. switch (type) {
  1324. case GGML_TYPE_Q4_0:
  1325. return dequantize_row_q4_0_cuda;
  1326. case GGML_TYPE_Q4_1:
  1327. return dequantize_row_q4_1_cuda;
  1328. case GGML_TYPE_Q5_0:
  1329. return dequantize_row_q5_0_cuda;
  1330. case GGML_TYPE_Q5_1:
  1331. return dequantize_row_q5_1_cuda;
  1332. case GGML_TYPE_Q8_0:
  1333. return dequantize_row_q8_0_cuda;
  1334. case GGML_TYPE_Q2_K:
  1335. return dequantize_row_q2_K_cuda;
  1336. case GGML_TYPE_Q3_K:
  1337. return dequantize_row_q3_K_cuda;
  1338. case GGML_TYPE_Q4_K:
  1339. return dequantize_row_q4_K_cuda;
  1340. case GGML_TYPE_Q5_K:
  1341. return dequantize_row_q5_K_cuda;
  1342. case GGML_TYPE_Q6_K:
  1343. return dequantize_row_q6_K_cuda;
  1344. case GGML_TYPE_F16:
  1345. return convert_fp16_to_fp32_cuda;
  1346. default:
  1347. return nullptr;
  1348. }
  1349. }
  1350. static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
  1351. const dim3 block_nums(1, nrows_x, nchannels_x);
  1352. const dim3 block_dims(WARP_SIZE, 1, 1);
  1353. mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x);
  1354. }
  1355. static void ggml_mul_mat_vec_nc_f16_f32_cuda(
  1356. const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
  1357. const int nchannels_x, const int channel_stride_x, cudaStream_t stream) {
  1358. const dim3 block_nums(1, nrows_x, nchannels_x);
  1359. const dim3 block_dims(WARP_SIZE, 1, 1);
  1360. mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
  1361. (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x);
  1362. }
  1363. static void ggml_cpy_f32_f32_cuda(
  1364. const char * cx, char * cdst, const int ne,
  1365. const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
  1366. const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
  1367. const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
  1368. cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
  1369. (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
  1370. }
  1371. static void ggml_cpy_f32_f16_cuda(
  1372. const char * cx, char * cdst, const int ne,
  1373. const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
  1374. const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
  1375. const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
  1376. cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
  1377. (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
  1378. }
  1379. static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
  1380. const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
  1381. scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
  1382. }
  1383. static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) {
  1384. GGML_ASSERT(nrows % 2 == 0);
  1385. const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
  1386. const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
  1387. const dim3 block_nums(num_blocks_x, nrows, 1);
  1388. rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
  1389. }
  1390. static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
  1391. const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
  1392. const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
  1393. const dim3 block_nums(block_num_x, nrows_x, 1);
  1394. diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
  1395. }
  1396. static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
  1397. const dim3 block_dims(WARP_SIZE, 1, 1);
  1398. const dim3 block_nums(1, nrows_x, 1);
  1399. soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
  1400. }
  1401. // buffer pool for cuda
  1402. #define MAX_CUDA_BUFFERS 256
  1403. struct scoped_spin_lock {
  1404. std::atomic_flag& lock;
  1405. scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
  1406. while (lock.test_and_set(std::memory_order_acquire)) {
  1407. ; // spin
  1408. }
  1409. }
  1410. ~scoped_spin_lock() {
  1411. lock.clear(std::memory_order_release);
  1412. }
  1413. scoped_spin_lock(const scoped_spin_lock&) = delete;
  1414. scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
  1415. };
  1416. struct cuda_buffer {
  1417. void * ptr = nullptr;
  1418. size_t size = 0;
  1419. };
  1420. static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
  1421. static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
  1422. static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
  1423. scoped_spin_lock lock(g_cuda_pool_lock);
  1424. int id;
  1425. CUDA_CHECK(cudaGetDevice(&id));
  1426. for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
  1427. cuda_buffer& b = g_cuda_buffer_pool[id][i];
  1428. if (b.size >= size && b.ptr != nullptr) {
  1429. void * ptr = b.ptr;
  1430. *actual_size = b.size;
  1431. b.ptr = nullptr;
  1432. b.size = 0;
  1433. return ptr;
  1434. }
  1435. }
  1436. void * ptr;
  1437. CUDA_CHECK(cudaMalloc((void **) &ptr, size));
  1438. *actual_size = size;
  1439. return ptr;
  1440. }
  1441. static void ggml_cuda_pool_free(void * ptr, size_t size) {
  1442. scoped_spin_lock lock(g_cuda_pool_lock);
  1443. int id;
  1444. CUDA_CHECK(cudaGetDevice(&id));
  1445. for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
  1446. cuda_buffer& b = g_cuda_buffer_pool[id][i];
  1447. if (b.ptr == nullptr) {
  1448. b.ptr = ptr;
  1449. b.size = size;
  1450. return;
  1451. }
  1452. }
  1453. fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
  1454. CUDA_CHECK(cudaFree(ptr));
  1455. }
  1456. static void * g_scratch_buffer = nullptr;
  1457. static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
  1458. static size_t g_scratch_offset = 0;
  1459. static int g_device_count = -1;
  1460. static int g_main_device = 0;
  1461. static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
  1462. static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
  1463. static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
  1464. void ggml_init_cublas() {
  1465. static bool initialized = false;
  1466. if (!initialized) {
  1467. CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
  1468. GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
  1469. int64_t total_vram = 0;
  1470. fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, g_device_count);
  1471. for (int id = 0; id < g_device_count; ++id) {
  1472. cudaDeviceProp prop;
  1473. CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
  1474. fprintf(stderr, " Device %d: %s\n", id, prop.name);
  1475. g_tensor_split[id] = total_vram;
  1476. total_vram += prop.totalGlobalMem;
  1477. }
  1478. for (int id = 0; id < g_device_count; ++id) {
  1479. g_tensor_split[id] /= total_vram;
  1480. }
  1481. for (int id = 0; id < g_device_count; ++id) {
  1482. CUDA_CHECK(cudaSetDevice(id));
  1483. // create main stream
  1484. CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking));
  1485. // create cublas handle
  1486. CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
  1487. CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
  1488. }
  1489. // configure logging to stdout
  1490. // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
  1491. initialized = true;
  1492. }
  1493. }
  1494. void ggml_cuda_set_tensor_split(const float * tensor_split) {
  1495. bool all_zero = true;
  1496. for (int i = 0; i < g_device_count; ++i) {
  1497. if (tensor_split[i] != 0.0f) {
  1498. all_zero = false;
  1499. break;
  1500. }
  1501. }
  1502. if (all_zero) {
  1503. return;
  1504. }
  1505. float split_sum = 0.0f;
  1506. for (int i = 0; i < g_device_count; ++i) {
  1507. g_tensor_split[i] = split_sum;
  1508. split_sum += tensor_split[i];
  1509. }
  1510. for (int i = 0; i < g_device_count; ++i) {
  1511. g_tensor_split[i] /= split_sum;
  1512. }
  1513. }
  1514. void * ggml_cuda_host_malloc(size_t size) {
  1515. if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
  1516. return nullptr;
  1517. }
  1518. void * ptr = nullptr;
  1519. cudaError_t err = cudaMallocHost((void **) &ptr, size);
  1520. if (err != cudaSuccess) {
  1521. // The allocation error can be bypassed. A null ptr will assigned out of this function.
  1522. // This can fixed the OOM error in WSL.
  1523. cudaGetLastError();
  1524. fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
  1525. size/1024.0/1024.0, cudaGetErrorString(err));
  1526. return nullptr;
  1527. }
  1528. return ptr;
  1529. }
  1530. void ggml_cuda_host_free(void * ptr) {
  1531. CUDA_CHECK(cudaFreeHost(ptr));
  1532. }
  1533. static cudaError_t ggml_cuda_cpy_tensor_2d(
  1534. void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
  1535. cudaMemcpyKind kind;
  1536. char * src_ptr;
  1537. if (src->backend == GGML_BACKEND_CPU) {
  1538. kind = cudaMemcpyHostToDevice;
  1539. src_ptr = (char *) src->data;
  1540. } else if (src->backend == GGML_BACKEND_GPU) {
  1541. kind = cudaMemcpyDeviceToDevice;
  1542. struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
  1543. int id;
  1544. CUDA_CHECK(cudaGetDevice(&id));
  1545. src_ptr = (char *) extra->data_device[id];
  1546. } else {
  1547. GGML_ASSERT(false);
  1548. }
  1549. char * dst_ptr = (char *) dst;
  1550. const int64_t ne0 = src->ne[0];
  1551. const int64_t nb0 = src->nb[0];
  1552. const int64_t nb1 = src->nb[1];
  1553. const int64_t nb2 = src->nb[2];
  1554. const int64_t nb3 = src->nb[3];
  1555. const enum ggml_type type = src->type;
  1556. const int64_t ts = ggml_type_size(type);
  1557. const int64_t bs = ggml_blck_size(type);
  1558. int64_t i1_diff = i1_high - i1_low;
  1559. const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
  1560. if (nb0 == ts && nb1 == ts*ne0/bs) {
  1561. return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
  1562. } else if (nb0 == ts) {
  1563. return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
  1564. } else {
  1565. for (int64_t i1 = 0; i1 < i1_diff; i1++) {
  1566. const void * rx = (const void *) ((const char *) x + i1*nb1);
  1567. void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
  1568. // pretend the row is a matrix with cols=1
  1569. cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
  1570. if (r != cudaSuccess) return r;
  1571. }
  1572. return cudaSuccess;
  1573. }
  1574. }
  1575. inline void ggml_cuda_op_add(
  1576. const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
  1577. float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
  1578. cudaStream_t & cudaStream_main){
  1579. GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
  1580. GGML_ASSERT(src1_ddf_i != nullptr);
  1581. GGML_ASSERT(dst_ddf_i != nullptr);
  1582. const int64_t ne0 = src0->ne[0];
  1583. const int64_t i01_diff = i01_high - i01_low;
  1584. // compute
  1585. if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
  1586. add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
  1587. } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
  1588. add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
  1589. } else {
  1590. GGML_ASSERT(false);
  1591. }
  1592. (void) src1;
  1593. (void) dst;
  1594. (void) src0_ddq_i;
  1595. (void) i02;
  1596. (void) i1;
  1597. }
  1598. inline void ggml_cuda_op_mul(
  1599. const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
  1600. float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
  1601. cudaStream_t & cudaStream_main){
  1602. GGML_ASSERT(src0_ddf_i != nullptr);
  1603. GGML_ASSERT(src1_ddf_i != nullptr);
  1604. GGML_ASSERT(dst_ddf_i != nullptr);
  1605. const int64_t ne00 = src0->ne[0];
  1606. const int64_t ne10 = src1->ne[0];
  1607. const int64_t ne11 = src1->ne[1];
  1608. for (int64_t i01 = i01_low; i01 < i01_high; i01++) {
  1609. const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
  1610. float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
  1611. float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
  1612. float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
  1613. // compute
  1614. mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
  1615. }
  1616. (void) dst;
  1617. (void) src0_ddq_i;
  1618. (void) i02;
  1619. }
  1620. inline void ggml_cuda_op_silu(
  1621. const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
  1622. float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
  1623. cudaStream_t & cudaStream_main){
  1624. GGML_ASSERT(src0_ddf_i != nullptr);
  1625. GGML_ASSERT(dst_ddf_i != nullptr);
  1626. const int64_t ne00 = src0->ne[0];
  1627. const int64_t i01_diff = i01_high - i01_low;
  1628. // compute
  1629. silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
  1630. (void) src1;
  1631. (void) dst;
  1632. (void) src0_ddq_i;
  1633. (void) src1_ddf_i;
  1634. (void) i02;
  1635. (void) i1;
  1636. }
  1637. inline void ggml_cuda_op_rms_norm(
  1638. const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
  1639. float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
  1640. cudaStream_t & cudaStream_main){
  1641. GGML_ASSERT(src0_ddf_i != nullptr);
  1642. GGML_ASSERT(dst_ddf_i != nullptr);
  1643. const int64_t ne00 = src0->ne[0];
  1644. const int64_t i01_diff = i01_high - i01_low;
  1645. // compute
  1646. rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
  1647. (void) src1;
  1648. (void) dst;
  1649. (void) src0_ddq_i;
  1650. (void) src1_ddf_i;
  1651. (void) i02;
  1652. (void) i1;
  1653. }
  1654. inline void ggml_cuda_op_dequantize_mul_mat_vec(
  1655. const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
  1656. float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
  1657. cudaStream_t & cudaStream_main){
  1658. GGML_ASSERT(src0_ddq_i != nullptr);
  1659. GGML_ASSERT(src1_ddf_i != nullptr);
  1660. GGML_ASSERT(dst_ddf_i != nullptr);
  1661. const int64_t ne00 = src0->ne[0];
  1662. const int64_t nrows = i01_high - i01_low;
  1663. // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
  1664. #ifdef GGML_CUDA_DMMV_F16
  1665. size_t ash;
  1666. dfloat * src1_dfloat = nullptr; // dfloat == half
  1667. bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
  1668. src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
  1669. src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
  1670. if (src1_convert_f16) {
  1671. src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
  1672. ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
  1673. ne00, 1, sizeof(float), 0, 0,
  1674. ne00, 1, sizeof(half), 0, 0, cudaStream_main);
  1675. }
  1676. #else
  1677. dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
  1678. #endif // GGML_CUDA_DMMV_F16
  1679. switch (src0->type) {
  1680. case GGML_TYPE_Q4_0:
  1681. dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  1682. break;
  1683. case GGML_TYPE_Q4_1:
  1684. dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  1685. break;
  1686. case GGML_TYPE_Q5_0:
  1687. dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  1688. break;
  1689. case GGML_TYPE_Q5_1:
  1690. dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  1691. break;
  1692. case GGML_TYPE_Q8_0:
  1693. dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  1694. break;
  1695. case GGML_TYPE_Q2_K:
  1696. dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
  1697. break;
  1698. case GGML_TYPE_Q3_K:
  1699. dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
  1700. break;
  1701. case GGML_TYPE_Q4_K:
  1702. dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
  1703. break;
  1704. case GGML_TYPE_Q5_K:
  1705. dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
  1706. break;
  1707. case GGML_TYPE_Q6_K:
  1708. dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
  1709. break;
  1710. case GGML_TYPE_F16:
  1711. convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  1712. break;
  1713. default:
  1714. GGML_ASSERT(false);
  1715. break;
  1716. }
  1717. #ifdef GGML_CUDA_DMMV_F16
  1718. if (src1_convert_f16) {
  1719. ggml_cuda_pool_free(src1_dfloat, ash);
  1720. }
  1721. #endif // GGML_CUDA_DMMV_F16
  1722. (void) src1;
  1723. (void) dst;
  1724. (void) src0_ddf_i;
  1725. (void) i02;
  1726. (void) i1;
  1727. }
  1728. inline void ggml_cuda_op_mul_mat_cublas(
  1729. const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
  1730. float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
  1731. cudaStream_t & cudaStream_main){
  1732. GGML_ASSERT(src0_ddf_i != nullptr);
  1733. GGML_ASSERT(src1_ddf_i != nullptr);
  1734. GGML_ASSERT(dst_ddf_i != nullptr);
  1735. const float alpha = 1.0f;
  1736. const float beta = 0.0f;
  1737. const int64_t ne00 = src0->ne[0];
  1738. const int64_t ne10 = src1->ne[0];
  1739. const int64_t ne11 = src1->ne[1];
  1740. const int64_t ne0 = dst->ne[0];
  1741. const int64_t i01_diff = i01_high - i01_low;
  1742. int id;
  1743. CUDA_CHECK(cudaGetDevice(&id));
  1744. // the main device has a larger memory buffer to hold the results from all GPUs
  1745. // ldc == nrows of the matrix that cuBLAS writes into
  1746. int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
  1747. CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], cudaStream_main));
  1748. CUBLAS_CHECK(
  1749. cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
  1750. i01_diff, ne11, ne10,
  1751. &alpha, src0_ddf_i, ne00,
  1752. src1_ddf_i, ne10,
  1753. &beta, dst_ddf_i, ldc));
  1754. (void) dst;
  1755. (void) src0_ddq_i;
  1756. (void) i02;
  1757. (void) i1;
  1758. }
  1759. inline void ggml_cuda_op_rope(
  1760. const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
  1761. float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
  1762. cudaStream_t & cudaStream_main){
  1763. GGML_ASSERT(src0_ddf_i != nullptr);
  1764. GGML_ASSERT(dst_ddf_i != nullptr);
  1765. const int64_t ne00 = src0->ne[0];
  1766. const int64_t i01_diff = i01_high - i01_low;
  1767. const int n_past = ((int32_t *) src1->data)[0];
  1768. const int n_dims = ((int32_t *) src1->data)[1];
  1769. const int mode = ((int32_t *) src1->data)[2];
  1770. GGML_ASSERT(mode == 0);
  1771. const float theta_scale = powf(10000.0, -2.0f/n_dims);
  1772. const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
  1773. // compute
  1774. rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
  1775. (void) dst;
  1776. (void) src0_ddq_i;
  1777. (void) src1_ddf_i;
  1778. (void) i1;
  1779. }
  1780. inline void ggml_cuda_op_diag_mask_inf(
  1781. const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
  1782. float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
  1783. cudaStream_t & cudaStream_main){
  1784. GGML_ASSERT(src0_ddf_i != nullptr);
  1785. GGML_ASSERT(dst_ddf_i != nullptr);
  1786. const int64_t ne00 = src0->ne[0];
  1787. const int64_t ne01 = src0->ne[1];
  1788. const int64_t i01_diff = i01_high - i01_low;
  1789. const int n_past = ((int32_t *) src1->data)[0];
  1790. // compute
  1791. diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
  1792. (void) dst;
  1793. (void) src0_ddq_i;
  1794. (void) src1_ddf_i;
  1795. (void) i02;
  1796. (void) i1;
  1797. }
  1798. inline void ggml_cuda_op_soft_max(
  1799. const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
  1800. float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
  1801. cudaStream_t & cudaStream_main){
  1802. GGML_ASSERT(src0_ddf_i != nullptr);
  1803. GGML_ASSERT(dst_ddf_i != nullptr);
  1804. const int64_t ne00 = src0->ne[0];
  1805. const int64_t i01_diff = i01_high - i01_low;
  1806. // compute
  1807. soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
  1808. (void) src1;
  1809. (void) dst;
  1810. (void) src0_ddq_i;
  1811. (void) src1_ddf_i;
  1812. (void) i02;
  1813. (void) i1;
  1814. }
  // ggml_cuda_op callback for GGML_OP_SCALE.
  // Multiplies every element of the current row range by the scalar stored in src1->data[0].
  // src1 is read on the host (ggml_cuda_op keeps src1 host-side for SCALE).
  inline void ggml_cuda_op_scale(
          const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
          float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
          cudaStream_t & cudaStream_main){
      GGML_ASSERT(src0_ddf_i != nullptr);
      GGML_ASSERT(dst_ddf_i != nullptr);

      const float   factor = ((float *) src1->data)[0];
      const int64_t nelems = src0->ne[0] * (i01_high - i01_low);   // row length * rows in this call

      scale_f32_cuda(src0_ddf_i, dst_ddf_i, factor, nelems, cudaStream_main);
      CUDA_CHECK(cudaGetLastError());

      // parameters required by the shared ggml_cuda_op_t signature but unused here
      (void) i1;
      (void) i02;
      (void) src1_ddf_i;
      (void) src0_ddq_i;
      (void) dst;
      (void) src1;
  }
  // Generic driver that runs a per-slice CUDA op callback (`op`) over src0 / src1 / dst.
  //
  // Participating devices:
  //   - only g_main_device, unless src0 is GGML_BACKEND_GPU_SPLIT;
  //   - for a split src0, each device handles its [row_low, row_high) share of src0's rows,
  //     as given by g_tensor_split.
  //
  // For each participating device this function:
  //   - obtains device pointers for src0/src1/dst. It borrows the tensor's own device buffer
  //     when it is already resident (and contiguous), otherwise it takes a temporary from the
  //     CUDA memory pool;
  //   - copies host-resident or non-contiguous inputs to the device on that device's main stream;
  //   - converts src0 to f32 via to_fp32_cuda when src0_needs_f32 is set and src0 is not f32;
  //   - calls `op` once per (i03, i02) slice, or once in total when flatten_rows is set;
  //   - copies the result into dst when dst is not resident on that device.
  //
  // Afterwards, pool temporaries are returned. When dst lives on the CPU, the main device is
  // synchronized so dst->data is valid on return.
  //
  // src1 may be nullptr (unary ops). For SCALE, DIAG_MASK_INF and ROPE, src1 is not copied to
  // the device. Presumably the op reads its parameters from src1->data on the host, as
  // ggml_cuda_op_scale / ggml_cuda_op_diag_mask_inf do.
  static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                           ggml_cuda_op_t op, bool src0_needs_f32, bool flatten_rows) {
      const int64_t ne00 = src0->ne[0];
      const int64_t ne01 = src0->ne[1];
      const int64_t ne02 = src0->ne[2];
      const int64_t ne03 = src0->ne[3];
      const int64_t nrows0 = ggml_nrows(src0);

      const bool use_src1 = src1 != nullptr;
      const int64_t ne10 = use_src1 ? src1->ne[0] : 1;
      const int64_t ne11 = use_src1 ? src1->ne[1] : 1;
      const int64_t ne12 = use_src1 ? src1->ne[2] : 1;
      const int64_t ne13 = use_src1 ? src1->ne[3] : 1;

      const int64_t ne0 = dst->ne[0];
      const int64_t ne1 = dst->ne[1];

      // byte strides of dst; nb[] is size_t, so keep them 64-bit instead of truncating to int
      // (a plain int overflows once a dst slice exceeds 2 GiB)
      const int64_t nb2 = dst->nb[2];
      const int64_t nb3 = dst->nb[3];

      GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
      GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);

      // strides for iteration over dims 3 and 2
      const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03;
      const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1;
      const int64_t src0_stride = ne00 * ne01 * stride_mod;
      const int64_t src1_stride = ne10 * ne11 * stride_mod;
      const int64_t dst_stride = ne0 * ne1 * stride_mod;

      const size_t src0_ts = ggml_type_size(src0->type);
      const size_t src0_bs = ggml_blck_size(src0->type);

      struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
      struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
      struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;

      const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
      const bool src0_is_contiguous = ggml_is_contiguous(src0);
      const bool src0_is_f32 = src0->type == GGML_TYPE_F32;

      const bool src1_is_contiguous = use_src1 && ggml_is_contiguous(src1);
      const bool src1_stays_on_host = use_src1 && (
          dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);

      const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;

      const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);

      // dd = data device
      char  * src0_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // quantized
      float * src0_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
      float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
      float * dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};

      // asq = actual size quantized, asf = actual size float; non-zero means "came from the pool"
      size_t src0_asq[GGML_CUDA_MAX_DEVICES] = {0};
      size_t src0_asf[GGML_CUDA_MAX_DEVICES] = {0};
      size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
      size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};

      // if multiple devices are used they need to wait for the main device
      // here an event is recorded that signifies that the main device has finished calculating the input data
      if (split && g_device_count > 1) {
          CUDA_CHECK(cudaSetDevice(g_main_device));
          CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device]));
      }

      for (int id = 0; id < g_device_count; ++id) {
          if (!split && id != g_main_device) {
              continue;
          }

          const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU && id == g_main_device;
          const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;

          int64_t row_low, row_high;
          if (split) {
              row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
              row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
          } else {
              row_low = 0;
              row_high = nrows0;
          }
          if (row_low == row_high) {
              continue;
          }

          int64_t row_diff = row_high - row_low;

          // an unchecked failure here would silently run everything below on the wrong device
          CUDA_CHECK(cudaSetDevice(id));
          cudaStream_t cudaStream_main = g_cudaStreams_main[id];

          // wait for main GPU data if necessary
          if (split && id != g_main_device) {
              CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device]));
          }

          if (src0_on_device && src0_is_contiguous) {
              if (src0_is_f32) {
                  src0_ddf[id] = (float *) src0_extra->data_device[id];
              } else {
                  src0_ddq[id] = (char *) src0_extra->data_device[id];
              }
          } else {
              if (src0_is_f32) {
                  src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
              } else {
                  src0_ddq[id] = (char *) ggml_cuda_pool_malloc(row_diff*ne00 * src0_ts/src0_bs, &src0_asq[id]);
              }
          }

          if (src0_needs_f32 && !src0_is_f32) {
              src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
          }

          if (use_src1 && !src1_stays_on_host) {
              if (src1_on_device && src1_is_contiguous) {
                  src1_ddf[id] = (float *) src1_extra->data_device[id];
              } else {
                  src1_ddf[id] = (float *) ggml_cuda_pool_malloc(num_iters*src1_stride * sizeof(float), &src1_asf[id]);
              }
          }
          if (dst_on_device) {
              dst_ddf[id] = (float *) dst_extra->data_device[id];
          } else {
              size_t size_dst_ddf = split ? row_diff*ne1 * sizeof(float) : num_iters*dst_stride * sizeof(float);
              dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
          }

          const int64_t i03_max = flatten_rows ? 1 : ne03;
          const int64_t i02_max = flatten_rows ? 1 : ne02;
          const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;

          for (int64_t i03 = 0; i03 < i03_max; i03++) {
              const int64_t i13 = i03 % ne13;
              for (int64_t i02 = 0; i02 < i02_max; i02++) {
                  const int64_t i12 = i02 % ne12;

                  const int64_t i0 = i03*ne02 + i02;

                  // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
                  const int64_t i0_offset_low = row_low/rows_per_iter;
                  const int64_t i0_offset_high = row_high/rows_per_iter;

                  int64_t i01_low = 0;
                  int64_t i01_high = rows_per_iter;
                  if (split) {
                      if (i0 < i0_offset_low || i0 > i0_offset_high) {
                          continue;
                      }
                      if (i0 == i0_offset_low) {
                          i01_low = row_low % rows_per_iter;
                      }
                      if (i0 == i0_offset_high) {
                          i01_high = row_high % rows_per_iter;
                      }
                  }

                  // There is possibly a bug in the Windows nvcc compiler regarding instruction reordering or optimizing out local variables.
                  // Removing the first assert or changing the order of the arguments causes the second assert to fail.
                  // Removing both asserts results in i01_high becoming 0 which in turn results in garbage output.
                  // The root cause seems to be a problem with i0_offset_high becoming 0 when it should always be >0 (for single GPU).
                  GGML_ASSERT(i01_low == 0 || g_device_count > 1);
                  GGML_ASSERT(i01_high == rows_per_iter || g_device_count > 1);

                  const int64_t i01_diff = i01_high - i01_low;
                  if (i01_diff == 0) {
                      continue;
                  }
                  const int64_t i11 = i13*ne12 + i12;

                  // for split tensors the data begins at i0 == i0_offset_low
                  char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
                  float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
                  float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
                  float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;

                  // for split tensors the data pointer needs to be rounded down
                  // to the bin edge for i03, i02 bins beyond the first
                  if (i0 - i0_offset_low > 0) {
                      GGML_ASSERT(!flatten_rows);
                      src0_ddq_i -= (row_low % ne01)*ne00 * src0_ts/src0_bs;
                      src0_ddf_i -= (row_low % ne01)*ne00;
                      dst_ddf_i -= (row_low % ne0)*ne1;
                  }

                  // the main device memory buffer can be on VRAM scratch, with space for all partial results
                  // in that case an offset on dst_ddf_i is needed
                  if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) {
                      dst_ddf_i += i01_low; // offset is 0 if no tensor split
                  }

                  // copy src0, src1 to device if necessary
                  if (use_src1 && !src1_stays_on_host) {
                      if (src1->backend == GGML_BACKEND_CPU) {
                          GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
                          int64_t nrows1 = flatten_rows ? nrows0 : ne11;
                          CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main));
                      } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
                          if (id != g_main_device) {
                              GGML_ASSERT(!flatten_rows);
                              float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
                              src1_ddf_i_source += i11*src1_stride;
                              CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
                                                         cudaMemcpyDeviceToDevice, cudaStream_main));
                          }
                      } else if (src1_on_device && !src1_is_contiguous) {
                          GGML_ASSERT(!split);
                          CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main));
                      } else {
                          GGML_ASSERT(false);
                      }
                  }

                  if (!src0_on_device || !src0_is_contiguous) {
                      if (src0_is_f32) {
                          CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
                      } else {
                          CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
                      }
                  }

                  // convert src0 to f32 if it is necessary for the ggml_cuda_op
                  if (src0_needs_f32 && !src0_is_f32) {
                      to_fp32_cuda(src0_ddq_i, src0_ddf_i, i01_diff*ne00, cudaStream_main);
                      CUDA_CHECK(cudaGetLastError());
                  }

                  // do the computation
                  op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
                  CUDA_CHECK(cudaGetLastError());

                  // copy dst to host or other device if necessary
                  if (!dst_on_device) {
                      void * dst_off_device;
                      cudaMemcpyKind kind;
                      if (dst->backend == GGML_BACKEND_CPU) {
                          dst_off_device = dst->data;
                          kind = cudaMemcpyDeviceToHost;
                      } else if (dst->backend == GGML_BACKEND_GPU) {
                          dst_off_device = dst_extra->data_device[g_main_device];
                          kind = cudaMemcpyDeviceToDevice;
                      } else {
                          GGML_ASSERT(false);
                      }
                      if (split) {
                          // src0 = weight matrix is saved as a transposed matrix for better memory layout.
                          // dst is NOT transposed.
                          // The outputs of cuBLAS matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
                          // Instead they need to be copied to the correct slice in ne0 = dst row index.
                          // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
                          for (int64_t j = 0; j < ne1; ++j) {
                              float * dhf_dst_i = (float *) ((char *) dst_off_device + (j*ne0 + i01_low)*sizeof(float) + i02*nb2 + i03*nb3);
                              CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i + j*i01_diff, i01_diff*sizeof(float), kind, cudaStream_main));
                          }
                      } else {
                          float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                          CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
                      }
                  }

                  // signify to main device that other device is done
                  if (split && g_device_count > 1 && id != g_main_device) {
                      CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main));
                  }
              }
          }
      }

      // Return per-device temporaries to the CUDA memory pool. No synchronization happens in this
      // loop. NOTE(review): this relies on ggml_cuda_pool_free only recycling the buffer, and on
      // later users being ordered on the same stream; confirm against the pool implementation.
      for (int id = 0; id < g_device_count; ++id) {
          if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) {
              continue;
          }

          CUDA_CHECK(cudaSetDevice(id));

          if (src0_asq[id] > 0) {
              ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
          }
          if (src0_asf[id] > 0) {
              ggml_cuda_pool_free(src0_ddf[id], src0_asf[id]);
          }
          if (src1_asf[id] > 0) {
              ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
          }
          if (dst_asf[id] > 0) {
              ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
          }
      }

      // main device waits for all other devices to be finished
      if (split && g_device_count > 1) {
          CUDA_CHECK(cudaSetDevice(g_main_device));
          for (int id = 0; id < g_device_count; ++id) {
              if (id != g_main_device) {
                  CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
              }
          }
      }

      // a host-side dst is only valid after all queued device->host copies have completed
      if (dst->backend == GGML_BACKEND_CPU) {
          CUDA_CHECK(cudaSetDevice(g_main_device));
          CUDA_CHECK(cudaDeviceSynchronize());
      }
  }
  // Element-wise dst = src0 + src1 on the GPU via ggml_cuda_op.
  // src0 may be f32 or f16, src1 must be f32, and dst may be f32 or f16.
  // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the
  // pointer arithmetic in ggml_cuda_op. Because flatten_rows == true, this makes no difference
  // in practice. A better solution would be nice, but right now it would require
  // disproportionate changes.
  void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
      GGML_ASSERT(
          (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
          src1->type == GGML_TYPE_F32 &&
          (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16));

      const bool convert_src0_to_f32 = false; // the add kernel consumes src0 in its stored type
      const bool flatten_rows        = true;  // process all rows in a single op call

      ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, convert_src0_to_f32, flatten_rows);
  }
  // Element-wise dst = src0 * src1 on the GPU; all three tensors must be f32.
  void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
      GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);

      const bool convert_src0_to_f32 = true;
      const bool flatten_rows        = false; // TODO ggml_cuda_op needs modification for flatten

      ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, convert_src0_to_f32, flatten_rows);
  }
  // SiLU activation on the GPU; src0 and dst must be f32 (src1 is passed through unused).
  void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
      GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);

      const bool convert_src0_to_f32 = true;
      const bool flatten_rows        = true; // all rows handled by a single op call

      ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, convert_src0_to_f32, flatten_rows);
  }
  // RMS normalization on the GPU; src0 and dst must be f32 (src1 is passed through unused).
  void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
      GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);

      const bool convert_src0_to_f32 = true;
      const bool flatten_rows        = true; // all rows handled by a single op call

      ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, convert_src0_to_f32, flatten_rows);
  }
  // Heuristic used when no operand is on the GPU: is this matmul worth offloading?
  // Requires src0 to be f32, f16 or quantized, src1 and dst to be f32, and
  // dst->ne[0], dst->ne[1], src1->ne[0] all to be at least 32.
  bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
      const bool src0_type_ok =
          src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type);
      const bool types_ok = src0_type_ok && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;

      // TODO: find the optimal values for these
      const int64_t min_dim = 32;
      const bool large_enough = dst->ne[0] >= min_dim && dst->ne[1] >= min_dim && src1->ne[0] >= min_dim;

      return types_ok && large_enough;
  }
  // Matrix-vector product for an f16 src0 and an f32 src1 that are both permuted with a 0213
  // permutation (nb[0] <= nb[1] and nb[2] <= nb[3], as asserted below). Everything must already
  // be resident on g_main_device; split tensors are not supported.
  void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
      GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
      GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
      GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
      GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
      GGML_ASSERT(src0->type == GGML_TYPE_F16);
      GGML_ASSERT(src1->type == GGML_TYPE_F32);

      const int64_t ne00 = src0->ne[0];
      const int64_t ne01 = src0->ne[1];
      const int64_t ne02 = src0->ne[2];

      CUDA_CHECK(cudaSetDevice(g_main_device));
      cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];

      struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
      void * src0_ddq = src0_extra->data_device[g_main_device];

      struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
      float * src1_ddf = (float *) src1_extra->data_device[g_main_device];

      struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
      float * dst_ddf = (float *) dst_extra->data_device[g_main_device];

      ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
      // surface launch-configuration errors here instead of at some later, unrelated CUDA call
      CUDA_CHECK(cudaGetLastError());
  }
  // Matrix-vector product for a non-contiguous (but not permuted) f16 src0 and a contiguous f32
  // src1. All operands must already be resident on g_main_device; split tensors are not supported.
  // src0's row/channel byte strides are converted to element (half) strides for the kernel.
  void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
      GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
      GGML_ASSERT(!ggml_is_permuted(src0));
      GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
      GGML_ASSERT(src0->type == GGML_TYPE_F16);
      GGML_ASSERT(src1->type == GGML_TYPE_F32);

      const int64_t ne00 = src0->ne[0];
      const int64_t ne01 = src0->ne[1];
      const int64_t ne02 = src0->ne[2];

      const int64_t nb01 = src0->nb[1];
      const int64_t nb02 = src0->nb[2];

      CUDA_CHECK(cudaSetDevice(g_main_device));
      cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];

      struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
      void * src0_ddq = src0_extra->data_device[g_main_device];

      struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
      float * src1_ddf = (float *) src1_extra->data_device[g_main_device];

      struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
      float * dst_ddf = (float *) dst_extra->data_device[g_main_device];

      const int row_stride_x = nb01 / sizeof(half);
      const int channel_stride_x = nb02 / sizeof(half);

      ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
      // surface launch-configuration errors here instead of at some later, unrelated CUDA call
      CUDA_CHECK(cudaGetLastError());
  }
  // Dispatches a matrix multiplication to the most specialized CUDA path available.
  // 1. All operands on device and src1 a single column:
  //    - both permuted -> ggml_cuda_mul_mat_vec_p021;
  //    - src0 non-contiguous and src1 contiguous -> ggml_cuda_mul_mat_vec_nc.
  // 2. f32 src0 -> cuBLAS.
  // 3. f16/quantized src0 -> dequantize_mul_mat_vec when src1 is a single column and src0's
  //    dimensions are multiples of GGML_CUDA_DMMV_X / GGML_CUDA_DMMV_Y; otherwise cuBLAS.
  // Any other src0 type aborts.
  void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
      const bool src0_on_gpu = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
      const bool all_on_device = src0_on_gpu && src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
      const bool src1_is_vector = src1->ne[1] == 1;

      if (all_on_device && src1_is_vector) {
          if (ggml_is_permuted(src0) && ggml_is_permuted(src1)) {
              ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
              return;
          }
          if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
              ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
              return;
          }
      }

      if (src0->type == GGML_TYPE_F32) {
          ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
          return;
      }

      if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
          const bool use_dmmv = src1_is_vector &&
              src0->ne[0] % GGML_CUDA_DMMV_X == 0 &&
              src0->ne[1] % GGML_CUDA_DMMV_Y == 0;
          if (use_dmmv) {
              ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false, false);
          } else {
              ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
          }
          return;
      }

      GGML_ASSERT(false);
  }
  // Scales src0 by the scalar held in src1 (read host-side by ggml_cuda_op_scale); src0/dst must be f32.
  void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
      GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);

      const bool convert_src0_to_f32 = true;
      const bool flatten_rows        = true;

      ggml_cuda_op(src0, src1, dst, ggml_cuda_op_scale, convert_src0_to_f32, flatten_rows);
  }
  // Copies src0 into src1 on g_main_device (ggml's CPY writes into src1; dst is unused here).
  // Supported conversions: f32 -> f32 and f32 -> f16. Both tensors must be GPU-resident and hold
  // the same number of elements. Neither may exceed INT_MAX bytes, and ne[3] must be 1 for both.
  // Strides (nb*) are forwarded so non-contiguous layouts are handled by the copy kernels.
  void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
      const int64_t ne = ggml_nelements(src0);
      GGML_ASSERT(ne == ggml_nelements(src1));

      GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
      GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);

      GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
      GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);

      const int64_t ne00 = src0->ne[0];
      const int64_t ne01 = src0->ne[1];
      GGML_ASSERT(src0->ne[3] == 1);

      const int64_t nb00 = src0->nb[0];
      const int64_t nb01 = src0->nb[1];
      const int64_t nb02 = src0->nb[2];

      const int64_t ne10 = src1->ne[0];
      const int64_t ne11 = src1->ne[1];
      GGML_ASSERT(src1->ne[3] == 1);

      const int64_t nb10 = src1->nb[0];
      const int64_t nb11 = src1->nb[1];
      const int64_t nb12 = src1->nb[2];

      CUDA_CHECK(cudaSetDevice(g_main_device));
      cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];

      const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
      const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;

      char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
      char * src1_ddc = (char *) src1_extra->data_device[g_main_device];

      if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
          ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
                                ne10, ne11, nb10, nb11, nb12, cudaStream_main);
      } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
          ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
                                ne10, ne11, nb10, nb11, nb12, cudaStream_main);
      } else {
          GGML_ASSERT(false);
      }
      // surface launch-configuration errors here instead of at some later, unrelated CUDA call
      CUDA_CHECK(cudaGetLastError());

      (void) dst;
  }
  // Diagonal -inf masking on the GPU; src1 carries n_past and is read host-side by the op callback.
  void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
      GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);

      const bool convert_src0_to_f32 = true;
      const bool flatten_rows        = true;

      ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, convert_src0_to_f32, flatten_rows);
  }
  // Row-wise softmax on the GPU; src0 and dst must be f32 (src1 is passed through unused).
  void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
      GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);

      const bool convert_src0_to_f32 = true;
      const bool flatten_rows        = true;

      ggml_cuda_op(src0, src1, dst, ggml_cuda_op_soft_max, convert_src0_to_f32, flatten_rows);
  }
  // Rotary position embedding on the GPU; src0 and dst must be f32, and src1 stays host-side.
  void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
      GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);

      const bool convert_src0_to_f32 = true;
      const bool flatten_rows        = false; // FIXME flatten changes results

      ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, convert_src0_to_f32, flatten_rows);
  }
  // No-op handler used for view-like ops (RESHAPE, VIEW, PERMUTE, TRANSPOSE) that need no GPU work.
  void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
      // all arguments are intentionally ignored
      (void) dst;
      (void) src1;
      (void) src0;
  }
  // Uploads the host buffer `data` for `tensor` to GPU memory and attaches a new
  // ggml_tensor_extra_gpu holding the per-device pointers.
  //   - GGML_BACKEND_GPU: the whole tensor is copied to g_main_device.
  //   - GGML_BACKEND_GPU_SPLIT: each device receives its [row_low, row_high) share of rows
  //     according to g_tensor_split, plus an event used by ggml_cuda_op for cross-device sync.
  //   - any other backend aborts.
  // Devices whose share is empty get no allocation (their data_device entry stays nullptr).
  void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
      // ggml_nrows returns int64_t; keep row counts 64-bit so large tensors are not truncated
      const int64_t nrows = ggml_nrows(tensor);
      const size_t nb1 = tensor->nb[1];
      ggml_backend backend = tensor->backend;

      struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
      memset(extra, 0, sizeof(*extra));

      for (int id = 0; id < g_device_count; ++id) {
          if (backend == GGML_BACKEND_GPU && id != g_main_device) {
              continue;
          }

          CUDA_CHECK(cudaSetDevice(id));

          int64_t row_low = 0;
          int64_t row_high = 0;
          if (backend == GGML_BACKEND_GPU) {
              row_low = 0;
              row_high = nrows;
          } else if (backend == GGML_BACKEND_GPU_SPLIT) {
              row_low = id == 0 ? 0 : nrows*g_tensor_split[id];
              row_high = id == g_device_count - 1 ? nrows : nrows*g_tensor_split[id + 1];
          } else {
              GGML_ASSERT(false);
          }
          if (row_low == row_high) {
              continue;
          }

          int64_t nrows_split = row_high - row_low;

          const size_t offset_split = row_low*nb1;
          const size_t size = ggml_nbytes_split(tensor, nrows_split);

          void * buf;
          CUDA_CHECK(cudaMalloc(&buf, size));
          void * buf_host = (char*)data + offset_split;

          // blocking host->device copy; an unchecked failure would leave garbage weights on the device
          CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));

          extra->data_device[id] = buf;

          if (backend == GGML_BACKEND_GPU_SPLIT) {
              CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming));
          }
      }

      tensor->extra = extra;
  }
  // Releases the device buffers and events owned by a GPU / GPU_SPLIT tensor's
  // ggml_tensor_extra_gpu, then deletes the extra. Tensors on other backends, or whose extra
  // is nullptr, are left untouched. tensor->extra is reset to nullptr so it cannot dangle.
  void ggml_cuda_free_data(struct ggml_tensor * tensor) {
      if (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) {
          return;
      }

      ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
      if (extra == nullptr) {
          return; // nothing to release
      }

      for (int id = 0; id < g_device_count; ++id) {
          if (extra->data_device[id] != nullptr) {
              CUDA_CHECK(cudaSetDevice(id));
              CUDA_CHECK(cudaFree(extra->data_device[id]));
          }

          if (extra->events[id] != nullptr) {
              CUDA_CHECK(cudaSetDevice(id));
              CUDA_CHECK(cudaEventDestroy(extra->events[id]));
          }
      }

      delete extra;
      tensor->extra = nullptr;
  }
  // Marks `tensor` as GGML_BACKEND_GPU and gives it device memory on g_main_device.
  // Before that, it recurses into a CPU-resident src0 produced by RESHAPE/TRANSPOSE/VIEW, and
  // into a CPU-resident src1 of a CPY.
  // The buffer comes from the first matching case:
  //   - inplace (data shared with src0, a VIEW, or force_inplace) with a GPU-resident src0:
  //     alias src0's device pointer (plus the view offset for VIEW);
  //   - CPY: alias src1's device pointer;
  //   - scratch: carve the next slot out of the shared g_scratch_buffer, which is lazily allocated
  //     and wraps to offset 0 when the tensor does not fit;
  //   - otherwise: a fresh zero-initialized cudaMalloc buffer.
  // Returns immediately, doing nothing, when scratch is requested but no scratch size is set.
  void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
      if (scratch && g_scratch_size == 0) {
          return;
      }

      // recursively assign CUDA buffers until a compute tensor is found
      if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
          const ggml_op src0_op = tensor->src0->op;
          if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
              ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace);
          }
      }
      if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
          ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace);
      }

      tensor->backend = GGML_BACKEND_GPU;
      struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
      memset(extra, 0, sizeof(*extra));

      const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
          tensor->op == GGML_OP_VIEW ||
          force_inplace;
      const size_t size = ggml_nbytes(tensor);

      CUDA_CHECK(cudaSetDevice(g_main_device));
      // inplace can be true without a src0 (force_inplace, or a VIEW without src0); guard the
      // dereference so such tensors fall through to the remaining cases instead of crashing
      const bool src0_on_gpu = tensor->src0 != nullptr &&
          (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT);
      if (inplace && src0_on_gpu) {
          struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
          char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
          size_t offset = 0;
          if (tensor->op == GGML_OP_VIEW) {
              // the view's byte offset into src0 is stored in opt[0]->data
              memcpy(&offset, tensor->opt[0]->data, sizeof(size_t));
          }
          extra->data_device[g_main_device] = src0_ddc + offset;
      } else if (tensor->op == GGML_OP_CPY) {
          struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src1->extra;
          void * src1_ddv = src1_extra->data_device[g_main_device];
          extra->data_device[g_main_device] = src1_ddv;
      } else if (scratch) {
          GGML_ASSERT(size <= g_scratch_size);
          if (g_scratch_offset + size > g_scratch_size) {
              g_scratch_offset = 0;
          }

          char * data = (char *) g_scratch_buffer;
          if (data == nullptr) {
              CUDA_CHECK(cudaMalloc(&data, g_scratch_size));
              g_scratch_buffer = data;
          }
          extra->data_device[g_main_device] = data + g_scratch_offset;

          g_scratch_offset += size;

          GGML_ASSERT(g_scratch_offset <= g_scratch_size);
      } else { // allocate new buffers outside of scratch
          void * data;
          CUDA_CHECK(cudaMalloc(&data, size));
          CUDA_CHECK(cudaMemset(data, 0, size));
          extra->data_device[g_main_device] = data;
      }

      tensor->extra = extra;
  }
  // Assigns device memory for `tensor`, preferring the shared VRAM scratch buffer.
  void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
      const bool use_scratch   = true;
      const bool force_inplace = false;
      ggml_cuda_assign_buffers_impl(tensor, use_scratch, force_inplace);
  }
  // Assigns device memory for `tensor` without using the scratch buffer (dedicated allocation or aliasing).
  void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
      const bool use_scratch   = false;
      const bool force_inplace = false;
      ggml_cuda_assign_buffers_impl(tensor, use_scratch, force_inplace);
  }
  // Assigns device memory for `tensor`, treating it as inplace so it aliases its src0's device buffer when possible.
  void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
      const bool use_scratch   = false;
      const bool force_inplace = true;
      ggml_cuda_assign_buffers_impl(tensor, use_scratch, force_inplace);
  }
  // Selects which CUDA device acts as the "main" device (where non-split tensors live and
  // results are gathered). Out-of-range requests, including negative ids, are rejected with a
  // warning and the current main device is kept. With more than one device, the chosen
  // device's name is logged.
  void ggml_cuda_set_main_device(int main_device) {
      if (main_device < 0) {
          // a negative id would later index per-device arrays out of bounds
          fprintf(stderr, "warning: cannot set main_device=%d because it is negative. Using device %d instead.\n",
                  main_device, g_main_device);
          return;
      }
      if (main_device >= g_device_count) {
          fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
                  main_device, g_device_count, g_main_device);
          return;
      }
      g_main_device = main_device;
      if (g_device_count > 1) {
          cudaDeviceProp prop;
          CUDA_CHECK(cudaGetDeviceProperties(&prop, g_main_device));
          fprintf(stderr, "%s: using device %d (%s) as main device\n", __func__, g_main_device, prop.name);
      }
  }
  // Records the VRAM scratch size used by ggml_cuda_assign_buffers. The buffer itself is
  // allocated lazily on first use; 0 disables scratch allocation.
  void ggml_cuda_set_scratch_size(size_t scratch_size) {
      g_scratch_size = scratch_size;
  }
  // Frees the lazily allocated VRAM scratch buffer, if any, and clears the global pointer.
  void ggml_cuda_free_scratch() {
      if (g_scratch_buffer != nullptr) {
          CUDA_CHECK(cudaFree(g_scratch_buffer));
          g_scratch_buffer = nullptr;
      }
  }
  // Graph-executor hook: returns true if the CUDA backend handles `tensor`, false to fall back
  // to the CPU path.
  // An op is handled when it has a CUDA implementation and at least one of tensor / src0 / src1
  // is device-resident. MUL_MAT is additionally handled for host-only operands when
  // ggml_cuda_can_mul_mat accepts them.
  // The work itself runs only on thread 0 (params->ith == 0) and is skipped for the
  // GGML_TASK_INIT and GGML_TASK_FINALIZE phases.
  bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
      const bool src0_on_device = tensor->src0 != nullptr &&
          (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT);
      const bool src1_on_device = tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU;
      const bool any_on_device = tensor->backend == GGML_BACKEND_GPU || src0_on_device || src1_on_device;

      ggml_cuda_func_t func = nullptr;
      switch (tensor->op) {
          case GGML_OP_ADD:           func = ggml_cuda_add;           break;
          case GGML_OP_MUL:           func = ggml_cuda_mul;           break;
          case GGML_OP_SILU:          func = ggml_cuda_silu;          break;
          case GGML_OP_RMS_NORM:      func = ggml_cuda_rms_norm;      break;
          case GGML_OP_MUL_MAT:       func = ggml_cuda_mul_mat;       break;
          case GGML_OP_SCALE:         func = ggml_cuda_scale;         break;
          case GGML_OP_CPY:           func = ggml_cuda_cpy;           break;
          case GGML_OP_RESHAPE:
          case GGML_OP_VIEW:
          case GGML_OP_PERMUTE:
          case GGML_OP_TRANSPOSE:     func = ggml_cuda_nop;           break;
          case GGML_OP_DIAG_MASK_INF: func = ggml_cuda_diag_mask_inf; break;
          case GGML_OP_SOFT_MAX:      func = ggml_cuda_soft_max;      break;
          case GGML_OP_ROPE:          func = ggml_cuda_rope;          break;
          default:
              return false;
      }

      // Without any device operand only MUL_MAT may still be offloaded, and only when the
      // shape/type heuristic approves it; everything else stays on the CPU.
      if (!any_on_device) {
          const bool offload_host_mul_mat = tensor->op == GGML_OP_MUL_MAT &&
              ggml_cuda_can_mul_mat(tensor->src0, tensor->src1, tensor);
          if (!offload_host_mul_mat) {
              return false;
          }
      }

      const bool not_first_thread = params->ith != 0;
      const bool skipped_phase = params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE;
      if (!not_first_thread && !skipped_phase) {
          func(tensor->src0, tensor->src1, tensor);
      }
      return true;
  }