// norm.cpp — SYCL normalization kernels and launchers (norm, group_norm, rms_norm, l2_norm)
  1. #include "norm.hpp"
  2. static void norm_f32(const float* x, float* dst, const int ncols, const float eps,
  3. const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
  4. const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
  5. item_ct1.get_local_id(1);
  6. const int tid = item_ct1.get_local_id(2);
  7. const int nthreads = item_ct1.get_local_range(2);
  8. const int nwarps = nthreads / WARP_SIZE;
  9. sycl::float2 mean_var = sycl::float2(0.f, 0.f);
  10. for (int col = tid; col < ncols; col += block_size) {
  11. const float xi = x[row * ncols + col];
  12. mean_var.x() += xi;
  13. mean_var.y() += xi * xi;
  14. }
  15. // sum up partial sums
  16. mean_var = warp_reduce_sum(mean_var, item_ct1);
  17. if (block_size > WARP_SIZE) {
  18. int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
  19. int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
  20. if (lane_id == 0) {
  21. s_sum[warp_id] = mean_var;
  22. }
  23. /*
  24. DPCT1118:0: SYCL group functions and algorithms must be encountered in
  25. converged control flow. You may need to adjust the code.
  26. */
  27. item_ct1.barrier(sycl::access::fence_space::local_space);
  28. mean_var = 0.f;
  29. size_t nreduce = nwarps / WARP_SIZE;
  30. for (size_t i = 0; i < nreduce; i += 1)
  31. {
  32. mean_var += s_sum[lane_id + i * WARP_SIZE];
  33. }
  34. mean_var = warp_reduce_sum(mean_var, item_ct1);
  35. }
  36. const float mean = mean_var.x() / ncols;
  37. const float var = mean_var.y() / ncols - mean * mean;
  38. const float inv_std = sycl::rsqrt(var + eps);
  39. for (int col = tid; col < ncols; col += block_size) {
  40. dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std;
  41. }
  42. }
  43. static void group_norm_f32(const float* x, float* dst, const int group_size, const int ne_elements, const float eps,
  44. const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
  45. int start = item_ct1.get_group(2) * group_size;
  46. int end = start + group_size;
  47. const int nthreads = item_ct1.get_local_range(2);
  48. const int nwarps = nthreads / WARP_SIZE;
  49. start += item_ct1.get_local_id(2);
  50. size_t nreduce = nwarps / WARP_SIZE;
  51. if (end >= ne_elements) {
  52. end = ne_elements;
  53. }
  54. float tmp = 0.0f; // partial sum for thread in warp
  55. for (int j = start; j < end; j += block_size) {
  56. tmp += x[j];
  57. }
  58. tmp = warp_reduce_sum(tmp, item_ct1);
  59. if (block_size > WARP_SIZE) {
  60. int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
  61. int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
  62. if (lane_id == 0) {
  63. s_sum[warp_id] = tmp;
  64. }
  65. /*
  66. DPCT1118:1: SYCL group functions and algorithms must be encountered in
  67. converged control flow. You may need to adjust the code.
  68. */
  69. /*
  70. DPCT1065:54: Consider replacing sycl::nd_item::barrier() with
  71. sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
  72. better performance if there is no access to global memory.
  73. */
  74. item_ct1.barrier();
  75. tmp = 0.f;
  76. for (size_t i = 0; i < nreduce; i += 1)
  77. {
  78. tmp += s_sum[lane_id + i * WARP_SIZE];
  79. }
  80. tmp = warp_reduce_sum(tmp, item_ct1);
  81. }
  82. float mean = tmp / group_size;
  83. tmp = 0.0f;
  84. for (int j = start; j < end; j += block_size) {
  85. float xi = x[j] - mean;
  86. dst[j] = xi;
  87. tmp += xi * xi;
  88. }
  89. tmp = warp_reduce_sum(tmp, item_ct1);
  90. if (block_size > WARP_SIZE) {
  91. int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
  92. int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
  93. if (lane_id == 0) {
  94. s_sum[warp_id] = tmp;
  95. }
  96. /*
  97. DPCT1118:2: SYCL group functions and algorithms must be encountered in
  98. converged control flow. You may need to adjust the code.
  99. */
  100. /*
  101. DPCT1065:55: Consider replacing sycl::nd_item::barrier() with
  102. sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
  103. better performance if there is no access to global memory.
  104. */
  105. item_ct1.barrier();
  106. tmp = 0.f;
  107. for (size_t i = 0; i < nreduce; i += 1)
  108. {
  109. tmp += s_sum[lane_id + i * WARP_SIZE];
  110. }
  111. tmp = warp_reduce_sum(tmp, item_ct1);
  112. }
  113. float variance = tmp / group_size;
  114. float scale = sycl::rsqrt(variance + eps);
  115. for (int j = start; j < end; j += block_size) {
  116. dst[j] *= scale;
  117. }
  118. }
// RMS-normalizes one row of x into dst: dst = x * rsqrt(mean(x^2) + eps).
// One work-group per row; each thread strides over the columns by block_size.
// s_sum is work-group local scratch, one float per warp; it may be nullptr on
// the single-warp launch path (block_size == WARP_SIZE).
static void rms_norm_f32(const float* x, float* dst, const int ncols, const float eps,
                         const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);
    const int tid = item_ct1.get_local_id(2);
    const int nthreads = item_ct1.get_local_range(2);
    const int nwarps = nthreads / WARP_SIZE;
    float tmp = 0.0f; // partial sum for thread in warp
    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row * ncols + col];
        tmp += xi * xi; // accumulate sum of squares
    }
    // sum up partial sums
    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        // Cross-warp reduction: lane 0 of each warp publishes its warp's sum
        // to local memory, then warp-sized groups of lanes re-reduce them.
        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        /*
        DPCT1118:3: SYCL group functions and algorithms must be encountered in
        converged control flow. You may need to adjust the code.
        */
        item_ct1.barrier(sycl::access::fence_space::local_space);
        // The launcher asserts nthreads % (WARP_SIZE * WARP_SIZE) == 0 on this
        // path, so nwarps divides evenly by WARP_SIZE.
        size_t nreduce = nwarps / WARP_SIZE;
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1)
        {
            tmp += s_sum[lane_id + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }
    const float mean = tmp / ncols;
    const float scale = sycl::rsqrt(mean + eps);
    for (int col = tid; col < ncols; col += block_size) {
        dst[row * ncols + col] = scale * x[row * ncols + col];
    }
}
// L2-normalizes one row of x into dst: dst = x * rsqrt(max(sum(x^2), eps^2)),
// i.e. divides by the row's Euclidean norm; eps acts as a lower bound on the
// norm so near-zero rows do not blow up. One work-group per row; each thread
// strides over the columns by block_size. s_sum is work-group local scratch,
// one float per warp; it may be nullptr when block_size == WARP_SIZE.
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
                        const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);
    const int tid = item_ct1.get_local_id(2);
    const int nthreads = item_ct1.get_local_range(2);
    const int nwarps = nthreads / WARP_SIZE;
    float tmp = 0.0f; // partial sum for thread in warp
    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row * ncols + col];
        tmp += xi * xi; // accumulate sum of squares
    }
    // sum up partial sums
    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        // Cross-warp reduction: lane 0 of each warp publishes its warp's sum
        // to local memory, then warp-sized groups of lanes re-reduce them.
        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        /*
        DPCT1118:3: SYCL group functions and algorithms must be encountered in
        converged control flow. You may need to adjust the code.
        */
        item_ct1.barrier(sycl::access::fence_space::local_space);
        // The launcher asserts nthreads % (WARP_SIZE * WARP_SIZE) == 0 on this
        // path, so nwarps divides evenly by WARP_SIZE.
        size_t nreduce = nwarps / WARP_SIZE;
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1)
        {
            tmp += s_sum[lane_id + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }
    const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
    for (int col = tid; col < ncols; col += block_size) {
        dst[row * ncols + col] = scale * x[row * ncols + col];
    }
}
  196. static void norm_f32_sycl(const float* x, float* dst, const int ncols,
  197. const int nrows, const float eps,
  198. queue_ptr stream, int device) {
  199. GGML_ASSERT(ncols % WARP_SIZE == 0);
  200. if (ncols < 1024) {
  201. const sycl::range<3> block_dims(1, 1, WARP_SIZE);
  202. stream->submit([&](sycl::handler& cgh) {
  203. cgh.parallel_for(
  204. sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
  205. block_dims),
  206. [=](sycl::nd_item<3> item_ct1)
  207. [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
  208. norm_f32(x, dst, ncols, eps, item_ct1,
  209. nullptr, WARP_SIZE);
  210. });
  211. });
  212. }
  213. else {
  214. const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
  215. assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
  216. const sycl::range<3> block_dims(1, 1, work_group_size);
  217. /*
  218. DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
  219. the limit. To get the device limit, query
  220. info::device::max_work_group_size. Adjust the work-group size if needed.
  221. */
  222. stream->submit([&](sycl::handler& cgh) {
  223. sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
  224. sycl::range<1>(work_group_size / WARP_SIZE), cgh);
  225. cgh.parallel_for(
  226. sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
  227. block_dims),
  228. [=](sycl::nd_item<3> item_ct1)
  229. [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
  230. norm_f32(x, dst, ncols, eps, item_ct1,
  231. get_pointer(s_sum_acc_ct1), work_group_size);
  232. });
  233. });
  234. }
  235. }
  236. static void group_norm_f32_sycl(const float* x, float* dst,
  237. const int num_groups, const float eps, const int group_size,
  238. const int ne_elements, queue_ptr stream, int device) {
  239. if (group_size < 1024) {
  240. const sycl::range<3> block_dims(1, 1, WARP_SIZE);
  241. stream->submit([&](sycl::handler& cgh) {
  242. const float eps_ct4 = eps;
  243. cgh.parallel_for(
  244. sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
  245. block_dims),
  246. [=](sycl::nd_item<3> item_ct1)
  247. [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
  248. group_norm_f32(
  249. x, dst, group_size, ne_elements, eps_ct4, item_ct1,
  250. nullptr, WARP_SIZE);
  251. });
  252. });
  253. }
  254. else {
  255. const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
  256. assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
  257. const sycl::range<3> block_dims(1, 1, work_group_size);
  258. /*
  259. DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
  260. the limit. To get the device limit, query
  261. info::device::max_work_group_size. Adjust the work-group size if needed.
  262. */
  263. stream->submit([&](sycl::handler& cgh) {
  264. sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
  265. cgh);
  266. const float eps_ct4 = eps;
  267. cgh.parallel_for(
  268. sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
  269. block_dims),
  270. [=](sycl::nd_item<3> item_ct1)
  271. [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
  272. group_norm_f32(x, dst, group_size, ne_elements,
  273. eps_ct4, item_ct1,
  274. get_pointer(s_sum_acc_ct1), work_group_size);
  275. });
  276. });
  277. }
  278. }
  279. static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
  280. const int nrows, const float eps,
  281. queue_ptr stream, int device) {
  282. GGML_ASSERT(ncols % WARP_SIZE == 0);
  283. // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
  284. if (ncols < 1024) {
  285. const sycl::range<3> block_dims(1, 1, WARP_SIZE);
  286. stream->submit([&](sycl::handler& cgh) {
  287. cgh.parallel_for(
  288. sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
  289. block_dims),
  290. [=](sycl::nd_item<3> item_ct1)
  291. [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
  292. rms_norm_f32(x, dst, ncols, eps, item_ct1,
  293. nullptr, WARP_SIZE);
  294. });
  295. });
  296. }
  297. else {
  298. const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
  299. assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
  300. const sycl::range<3> block_dims(1, 1, work_group_size);
  301. /*
  302. DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
  303. the limit. To get the device limit, query
  304. info::device::max_work_group_size. Adjust the work-group size if needed.
  305. */
  306. stream->submit([&](sycl::handler& cgh) {
  307. sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
  308. cgh);
  309. cgh.parallel_for(
  310. sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
  311. block_dims),
  312. [=](sycl::nd_item<3> item_ct1)
  313. [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
  314. rms_norm_f32(x, dst, ncols, eps, item_ct1,
  315. get_pointer(s_sum_acc_ct1), work_group_size);
  316. });
  317. });
  318. }
  319. }
  320. static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
  321. const int nrows, const float eps,
  322. queue_ptr stream, int device) {
  323. GGML_ASSERT(ncols % WARP_SIZE == 0);
  324. // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
  325. if (ncols < 1024) {
  326. const sycl::range<3> block_dims(1, 1, WARP_SIZE);
  327. stream->submit([&](sycl::handler& cgh) {
  328. cgh.parallel_for(
  329. sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
  330. block_dims),
  331. [=](sycl::nd_item<3> item_ct1)
  332. [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
  333. l2_norm_f32(x, dst, ncols, eps, item_ct1,
  334. nullptr, WARP_SIZE);
  335. });
  336. });
  337. }
  338. else {
  339. const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
  340. assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
  341. const sycl::range<3> block_dims(1, 1, work_group_size);
  342. /*
  343. DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
  344. the limit. To get the device limit, query
  345. info::device::max_work_group_size. Adjust the work-group size if needed.
  346. */
  347. stream->submit([&](sycl::handler& cgh) {
  348. sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
  349. cgh);
  350. cgh.parallel_for(
  351. sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
  352. block_dims),
  353. [=](sycl::nd_item<3> item_ct1)
  354. [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
  355. l2_norm_f32(x, dst, ncols, eps, item_ct1,
  356. get_pointer(s_sum_acc_ct1), work_group_size);
  357. });
  358. });
  359. }
  360. }
  361. void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
  362. GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  363. GGML_ASSERT(dst->type == GGML_TYPE_F32);
  364. const int64_t ne00 = dst->src[0]->ne[0];
  365. const int64_t nrows = ggml_nrows(dst->src[0]);
  366. dpct::queue_ptr main_stream = ctx.stream();
  367. SYCL_CHECK(ggml_sycl_set_device(ctx.device));
  368. const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
  369. float * dst_dd = static_cast<float *>(dst->data);
  370. float eps;
  371. memcpy(&eps, dst->op_params, sizeof(float));
  372. norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
  373. }
  374. void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
  375. GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  376. GGML_ASSERT(dst->type == GGML_TYPE_F32);
  377. int num_groups = dst->op_params[0];
  378. dpct::queue_ptr main_stream = ctx.stream();
  379. SYCL_CHECK(ggml_sycl_set_device(ctx.device));
  380. const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
  381. float * dst_dd = static_cast<float *>(dst->data);
  382. float eps;
  383. memcpy(&eps, dst->op_params + 1, sizeof(float));
  384. int group_size = dst->src[0]->ne[0] * dst->src[0]->ne[1] * ((dst->src[0]->ne[2] + num_groups - 1) / num_groups);
  385. group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, dst->src[0]->ne[0] * dst->src[0]->ne[1] * dst->src[0]->ne[2], main_stream, ctx.device);
  386. }
  387. void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  388. GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  389. GGML_ASSERT(dst->type == GGML_TYPE_F32);
  390. const int64_t ne00 = dst->src[0]->ne[0];
  391. const int64_t nrows = ggml_nrows(dst->src[0]);
  392. dpct::queue_ptr main_stream = ctx.stream();
  393. SYCL_CHECK(ggml_sycl_set_device(ctx.device));
  394. const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
  395. float * dst_dd = static_cast<float *>(dst->data);
  396. float eps;
  397. memcpy(&eps, dst->op_params, sizeof(float));
  398. rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
  399. }
  400. void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
  401. GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  402. GGML_ASSERT(dst->type == GGML_TYPE_F32);
  403. dpct::queue_ptr main_stream = ctx.stream();
  404. SYCL_CHECK(ggml_sycl_set_device(ctx.device));
  405. const int64_t ne00 = dst->src[0]->ne[0];
  406. const int64_t nrows = ggml_nrows(dst->src[0]);
  407. const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
  408. float * dst_dd = static_cast<float *>(dst->data);
  409. float eps;
  410. memcpy(&eps, dst->op_params, sizeof(float));
  411. l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
  412. }