/* ggml-opencl.c (15 KB) */
  1. #include "ggml-opencl.h"
  2. #define CL_TARGET_OPENCL_VERSION 110
  3. #include <clblast_c.h>
  4. #include <stdlib.h>
  5. #include <stdio.h>
  6. #include <string.h>
  7. #include "ggml.h"
/*
 * Turn the macro's argument list into a string literal so the OpenCL C kernels
 * below can be written as plain source text instead of one quoted string.
 * Comments are removed and whitespace runs collapse to single spaces before
 * stringification, so comments inside the kernel source are not part of the
 * string handed to the OpenCL compiler.
 */
#define MULTILINE_QUOTE(...) #__VA_ARGS__

/*
 * OpenCL C source for the per-type dequantization kernels. It is compiled at
 * runtime by build_program_from_source() from ggml_cl_init(). Each kernel
 * expands quantized 32-value blocks in x into floats in y (y[i*32 .. i*32+31]
 * for block i).
 */
static const char * program_source = MULTILINE_QUOTE(
/* OpenCL C has no stdint.h, so alias the fixed-width names used below. */
typedef char int8_t;
typedef uchar uint8_t;
typedef int int32_t;
typedef uint uint32_t;
/*
 * Packed device-side views of the quantized blocks uploaded by the host.
 * NOTE(review): these presumably mirror the host-side ggml block_q* structs
 * and must stay byte-for-byte in sync with them - confirm against ggml.c.
 */
struct __attribute__ ((packed)) block_q4_0
{
    half d;
    uint8_t qs[16]; /* QK4_0 / 2 */
};
struct __attribute__ ((packed)) block_q4_1
{
    half d;
    half m;
    uint8_t qs[16]; /* QK4_1 / 2 */
};
struct __attribute__ ((packed)) block_q5_0
{
    half d;
    uint32_t qh;
    uint8_t qs[16]; /* QK5_0 / 2 */
};
struct __attribute__ ((packed)) block_q5_1
{
    half d;
    half m;
    uint32_t qh;
    uint8_t qs[16]; /* QK5_1 / 2 */
};
struct __attribute__ ((packed)) block_q8_0
{
    half d;
    int8_t qs[32]; /* QK8_0 */
};
/*
 * Q4_0: block index i comes from the global id and byte index j from the
 * local id, so j spans the work-group size - 16 as launched by the host
 * wrapper. The low nibble goes to y[i*32 + j], the high nibble to
 * y[i*32 + j + 16]. Each value is (q - 8) * d.
 */
__kernel void dequantize_row_q4_0(__global struct block_q4_0* x, __global float* y) {
    const uint i = get_global_id(0) / 32; /* QK4_0 */
    const uint j = get_local_id(0);
    const float d = vload_half(0, (__global half*) &x[i].d);
    const int x0 = (x[i].qs[j] & 0xf) - 8;
    const int x1 = (x[i].qs[j] >> 4) - 8;
    y[i*32 + j + 0 ] = x0*d;
    y[i*32 + j + 16] = x1*d;
}
/* Q4_1: same nibble layout as Q4_0, but unsigned with an offset: value = q * d + m. */
__kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float* y) {
    const uint i = get_global_id(0) / 32; /* QK4_1 */
    const uint j = get_local_id(0);
    const float d = vload_half(0, (__global half*) &x[i].d);
    const float m = vload_half(0, (__global half*) &x[i].m);
    const int x0 = (x[i].qs[j] & 0xf);
    const int x1 = (x[i].qs[j] >> 4);
    y[i*32 + j + 0 ] = x0*d + m;
    y[i*32 + j + 16] = x1*d + m;
}
/*
 * Q5_0: the 5th bit of each value comes from qh. Bit j supplies the
 * low-nibble value and bit j + 16 the high-nibble value; each is moved to
 * bit 4 and OR-ed onto the nibble. Each value is (q - 16) * d.
 */
__kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float* y) {
    const uint i = get_global_id(0) / 32; /* QK5_0 */
    const uint j = get_local_id(0);
    const float d = vload_half(0, (__global half*) &x[i].d);
    uint32_t qh = x[i].qh;
    const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
    const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
    const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
    const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
    y[i*32 + j + 0 ] = x0*d;
    y[i*32 + j + 16] = x1*d;
}
/* Q5_1: 5-bit unpacking as in Q5_0, but unsigned with an offset: value = q * d + m. */
__kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float* y) {
    const uint i = get_global_id(0) / 32; /* QK5_1 */
    const uint j = get_local_id(0);
    const float d = vload_half(0, (__global half*) &x[i].d);
    const float m = vload_half(0, (__global half*) &x[i].m);
    uint32_t qh = x[i].qh;
    const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
    const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
    const int x0 = (x[i].qs[j] & 0xf) | xh_0;
    const int x1 = (x[i].qs[j] >> 4) | xh_1;
    y[i*32 + j + 0 ] = x0*d + m;
    y[i*32 + j + 16] = x1*d + m;
}
/*
 * Q8_0: one signed byte per value, value = q * d. Here j must span all 32
 * values, matching the work-group size of 32 used by the host wrapper.
 */
__kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float* y) {
    const uint i = get_global_id(0) / 32; /* QK8_0 */
    const uint j = get_local_id(0);
    const float d = vload_half(0, (__global half*) &x[i].d);
    y[i*32 + j] = x[i].qs[j]*d;
}
);
  94. #define CL_CHECK(err) \
  95. do { \
  96. cl_int err_ = (err); \
  97. if (err_ != CL_SUCCESS) { \
  98. fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n", \
  99. #err, err_, __FILE__, __LINE__); \
  100. exit(1); \
  101. } \
  102. } while (0)
  103. #define CLBLAST_CHECK(err) \
  104. do { \
  105. CLBlastStatusCode err_ = (err); \
  106. if (err_ != CLBlastSuccess) { \
  107. fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n", \
  108. #err, err_, __FILE__, __LINE__); \
  109. exit(1); \
  110. } \
  111. } while (0)
  112. static cl_platform_id platform;
  113. static cl_device_id device;
  114. static cl_context context;
  115. static cl_command_queue queue;
  116. static cl_program program;
  117. static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q5_0, kernel_q5_1, kernel_q8_0;
  118. static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
  119. static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0;
  120. static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) {
  121. cl_program p;
  122. char *program_log;
  123. size_t program_size, log_size;
  124. int err;
  125. program_size = strlen(program_buffer);
  126. p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err);
  127. if(err < 0) {
  128. fprintf(stderr, "OpenCL error creating program");
  129. exit(1);
  130. }
  131. err = clBuildProgram(p, 0, NULL, NULL, NULL, NULL);
  132. if(err < 0) {
  133. clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
  134. program_log = (char*) malloc(log_size + 1);
  135. program_log[log_size] = '\0';
  136. clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL);
  137. printf("%s\n", program_log);
  138. free(program_log);
  139. exit(1);
  140. }
  141. return p;
  142. }
  143. void ggml_cl_init(void) {
  144. cl_int err = 0;
  145. struct cl_device;
  146. struct cl_platform {
  147. cl_platform_id id;
  148. unsigned number;
  149. char name[128];
  150. char vendor[128];
  151. struct cl_device * devices;
  152. unsigned n_devices;
  153. struct cl_device * default_device;
  154. };
  155. struct cl_device {
  156. struct cl_platform * platform;
  157. cl_device_id id;
  158. unsigned number;
  159. cl_device_type type;
  160. char name[128];
  161. };
  162. enum { NPLAT = 16, NDEV = 16 };
  163. struct cl_platform platforms[NPLAT];
  164. unsigned n_platforms = 0;
  165. struct cl_device devices[NDEV];
  166. unsigned n_devices = 0;
  167. struct cl_device * default_device = NULL;
  168. platform = NULL;
  169. device = NULL;
  170. cl_platform_id platform_ids[NPLAT];
  171. CL_CHECK(clGetPlatformIDs(NPLAT, platform_ids, &n_platforms));
  172. for (unsigned i = 0; i < n_platforms; i++) {
  173. struct cl_platform * p = &platforms[i];
  174. p->number = i;
  175. p->id = platform_ids[i];
  176. CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_NAME, sizeof(p->name), &p->name, NULL));
  177. CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_VENDOR, sizeof(p->vendor), &p->vendor, NULL));
  178. cl_device_id device_ids[NDEV];
  179. cl_int clGetDeviceIDsError = clGetDeviceIDs(p->id, CL_DEVICE_TYPE_ALL, NDEV, device_ids, &p->n_devices);
  180. if (clGetDeviceIDsError == CL_DEVICE_NOT_FOUND) {
  181. p->n_devices = 0;
  182. } else {
  183. CL_CHECK(clGetDeviceIDsError);
  184. }
  185. p->devices = p->n_devices > 0 ? &devices[n_devices] : NULL;
  186. p->default_device = NULL;
  187. for (unsigned j = 0; j < p->n_devices; j++) {
  188. struct cl_device * d = &devices[n_devices];
  189. d->number = n_devices++;
  190. d->id = device_ids[j];
  191. d->platform = p;
  192. CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL));
  193. CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL));
  194. if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) {
  195. p->default_device = d;
  196. }
  197. }
  198. if (default_device == NULL && p->default_device != NULL) {
  199. default_device = p->default_device;
  200. }
  201. }
  202. if (n_devices == 0) {
  203. fprintf(stderr, "ggml_opencl: could find any OpenCL devices.\n");
  204. exit(1);
  205. }
  206. char * user_platform_string = getenv("GGML_OPENCL_PLATFORM");
  207. char * user_device_string = getenv("GGML_OPENCL_DEVICE");
  208. int user_platform_number = -1;
  209. int user_device_number = -1;
  210. unsigned n;
  211. if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) {
  212. user_platform_number = (int)n;
  213. }
  214. if (user_device_string != NULL && sscanf(user_device_string, " %u", &n) == 1 && n < n_devices) {
  215. user_device_number = (int)n;
  216. }
  217. struct cl_device * selected_devices = devices;
  218. unsigned n_selected_devices = n_devices;
  219. if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) {
  220. for (unsigned i = 0; i < n_platforms; i++) {
  221. struct cl_platform * p = &platforms[i];
  222. if (strstr(p->name, user_platform_string) != NULL ||
  223. strstr(p->vendor, user_platform_string) != NULL) {
  224. user_platform_number = (int)i;
  225. break;
  226. }
  227. }
  228. if (user_platform_number == -1) {
  229. fprintf(stderr, "ggml_opencl: no platform matching '%s' was found.\n", user_platform_string);
  230. exit(1);
  231. }
  232. }
  233. if (user_platform_number != -1) {
  234. struct cl_platform * p = &platforms[user_platform_number];
  235. selected_devices = p->devices;
  236. n_selected_devices = p->n_devices;
  237. default_device = p->default_device;
  238. if (n_selected_devices == 0) {
  239. fprintf(stderr, "ggml_opencl: selected platform '%s' does not have any devices.\n", p->name);
  240. exit(1);
  241. }
  242. }
  243. if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) {
  244. for (unsigned i = 0; i < n_selected_devices; i++) {
  245. struct cl_device * d = &selected_devices[i];
  246. if (strstr(d->name, user_device_string) != NULL) {
  247. user_device_number = d->number;
  248. break;
  249. }
  250. }
  251. if (user_device_number == -1) {
  252. fprintf(stderr, "ggml_opencl: no device matching '%s' was found.\n", user_device_string);
  253. exit(1);
  254. }
  255. }
  256. if (user_device_number != -1) {
  257. selected_devices = &devices[user_device_number];
  258. n_selected_devices = 1;
  259. default_device = &selected_devices[0];
  260. }
  261. GGML_ASSERT(n_selected_devices > 0);
  262. if (default_device == NULL) {
  263. default_device = &selected_devices[0];
  264. }
  265. fprintf(stderr, "ggml_opencl: selecting platform: '%s'\n", default_device->platform->name);
  266. fprintf(stderr, "ggml_opencl: selecting device: '%s'\n", default_device->name);
  267. if (default_device->type != CL_DEVICE_TYPE_GPU) {
  268. fprintf(stderr, "ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name);
  269. }
  270. platform = default_device->platform->id;
  271. device = default_device->id;
  272. cl_context_properties properties[] = {
  273. (intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)platform, 0
  274. };
  275. CL_CHECK((context = clCreateContext(properties, 1, &device, NULL, NULL, &err), err));
  276. CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err),
  277. (err != CL_INVALID_PROPERTY && err != CL_INVALID_VALUE ? err :
  278. (queue = clCreateCommandQueue(context, device, 0, &err), err)
  279. )));
  280. program = build_program_from_source(context, device, program_source);
  281. // Prepare dequantize kernels
  282. CL_CHECK((kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err), err));
  283. CL_CHECK((kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err), err));
  284. CL_CHECK((kernel_q5_0 = clCreateKernel(program, "dequantize_row_q5_0", &err), err));
  285. CL_CHECK((kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err), err));
  286. CL_CHECK((kernel_q8_0 = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
  287. }
  288. static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
  289. if (req_size <= *cur_size) {
  290. return;
  291. }
  292. // Reallocate buffer with enough space
  293. if (*cur_size > 0) {
  294. clReleaseMemObject(*buf);
  295. }
  296. cl_int err;
  297. CL_CHECK((*buf = clCreateBuffer(context, flags, req_size, NULL, &err), err));
  298. *cur_size = req_size;
  299. }
  300. void ggml_cl_sgemm_wrapper(
  301. const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b,
  302. const int m, const int n, const int k,
  303. const float alpha, const void *host_a, const int lda,
  304. const float *host_b, const int ldb, const float beta,
  305. float *host_c, const int ldc, const int btype) {
  306. cl_kernel kernel;
  307. size_t global = n * k, local, size_qb;
  308. bool dequant;
  309. switch (btype) {
  310. case GGML_TYPE_F32:
  311. dequant = false;
  312. break;
  313. case GGML_TYPE_Q4_0:
  314. dequant = true;
  315. kernel = kernel_q4_0;
  316. local = 16;
  317. size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
  318. break;
  319. case GGML_TYPE_Q4_1:
  320. dequant = true;
  321. kernel = kernel_q4_1;
  322. local = 16;
  323. size_qb = global * (sizeof(ggml_fp16_t) * 2 + local) / 32;
  324. break;
  325. case GGML_TYPE_Q5_0:
  326. dequant = true;
  327. kernel = kernel_q5_0;
  328. local = 16;
  329. size_qb = global * (sizeof(ggml_fp16_t) + sizeof(uint32_t) + local) / 32;
  330. break;
  331. case GGML_TYPE_Q5_1:
  332. dequant = true;
  333. kernel = kernel_q5_1;
  334. local = 16;
  335. size_qb = global * (sizeof(ggml_fp16_t) * 2 + sizeof(uint32_t) + local) / 32;
  336. break;
  337. case GGML_TYPE_Q8_0:
  338. dequant = true;
  339. kernel = kernel_q8_0;
  340. local = 32;
  341. size_qb = global * (sizeof(ggml_fp16_t) + local) / 32;
  342. break;
  343. default:
  344. fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
  345. abort();
  346. }
  347. const size_t size_a = m * k * sizeof(float);
  348. const size_t size_b = n * k * sizeof(float);
  349. const size_t size_c = m * n * sizeof(float);
  350. // Prepare buffers
  351. ggml_cl_malloc(size_a, &cl_size_a, CL_MEM_READ_ONLY, &cl_buffer_a);
  352. if (dequant) {
  353. ggml_cl_malloc(size_qb, &cl_size_qb, CL_MEM_READ_ONLY, &cl_buffer_qb);
  354. }
  355. ggml_cl_malloc(size_b, &cl_size_b, CL_MEM_READ_WRITE, &cl_buffer_b);
  356. ggml_cl_malloc(size_c, &cl_size_c, CL_MEM_WRITE_ONLY, &cl_buffer_c);
  357. cl_event ev_a, ev_qb, ev_b;
  358. if (dequant) {
  359. CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb));
  360. CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b));
  361. CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb));
  362. } else {
  363. CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b));
  364. }
  365. CL_CHECK(clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a));
  366. if (dequant) {
  367. CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b));
  368. CL_CHECK(clReleaseEvent(ev_qb));
  369. }
  370. CL_CHECK(clWaitForEvents(1, &ev_a));
  371. CL_CHECK(clWaitForEvents(1, &ev_b));
  372. CL_CHECK(clReleaseEvent(ev_a));
  373. CL_CHECK(clReleaseEvent(ev_b));
  374. cl_event ev_sgemm;
  375. CLBLAST_CHECK(CLBlastSgemm(
  376. (CLBlastLayout)order,
  377. (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
  378. m, n, k,
  379. alpha,
  380. cl_buffer_a, 0, lda,
  381. cl_buffer_b, 0, ldb,
  382. beta,
  383. cl_buffer_c, 0, ldc,
  384. &queue, &ev_sgemm));
  385. cl_event ev_c;
  386. CL_CHECK(clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c));
  387. // Wait for completion
  388. CL_CHECK(clWaitForEvents(1, &ev_c));
  389. CL_CHECK(clReleaseEvent(ev_sgemm));
  390. CL_CHECK(clReleaseEvent(ev_c));
  391. }