norm.cpp

#include "norm.hpp"
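
// Layer norm kernel: each work-group handles one row of ncols values. Threads
// accumulate a partial (sum, sum of squares) pair, reduce it across the warp
// (and across warps via the s_sum scratch buffer when block_size > WARP_SIZE),
// then write (x - mean) * rsqrt(var + eps) back to dst.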
static void norm_f32(const float* x, float* dst, const int ncols, const float eps,
                     const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);
    const int tid = item_ct1.get_local_id(2);
    const int nthreads = item_ct1.get_local_range(2);
    const int nwarps = nthreads / WARP_SIZE;
    sycl::float2 mean_var = sycl::float2(0.f, 0.f);

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row * ncols + col];
        mean_var.x() += xi;
        mean_var.y() += xi * xi;
    }

    // sum up partial sums
    mean_var = warp_reduce_sum(mean_var, item_ct1);
    if (block_size > WARP_SIZE) {
        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = mean_var;
        }
        /*
        DPCT1118:0: SYCL group functions and algorithms must be encountered in
        converged control flow. You may need to adjust the code.
        */
        item_ct1.barrier(sycl::access::fence_space::local_space);
        mean_var = 0.f;
        size_t nreduce = nwarps / WARP_SIZE;
        for (size_t i = 0; i < nreduce; i += 1) {
            mean_var += s_sum[lane_id + i * WARP_SIZE];
        }
        mean_var = warp_reduce_sum(mean_var, item_ct1);
    }

    const float mean = mean_var.x() / ncols;
    const float var = mean_var.y() / ncols - mean * mean;
    const float inv_std = sycl::rsqrt(var + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std;
    }
}
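
// Group norm kernel: each work-group normalizes one contiguous group of
// group_size elements. Mean and variance are computed in two passes, with dst
// used to hold the centered values between the passes.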
static void group_norm_f32(const float* x, float* dst, const int group_size, const int ne_elements, const float eps,
                           const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
    int start = item_ct1.get_group(2) * group_size;
    int end = start + group_size;
    const int nthreads = item_ct1.get_local_range(2);
    const int nwarps = nthreads / WARP_SIZE;
    start += item_ct1.get_local_id(2);
    size_t nreduce = nwarps / WARP_SIZE;

    if (end >= ne_elements) {
        end = ne_elements;
    }

    float tmp = 0.0f; // partial sum for thread in warp

    for (int j = start; j < end; j += block_size) {
        tmp += x[j];
    }

    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        /*
        DPCT1118:1: SYCL group functions and algorithms must be encountered in
        converged control flow. You may need to adjust the code.
        */
        /*
        DPCT1065:54: Consider replacing sycl::nd_item::barrier() with
        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
        better performance if there is no access to global memory.
        */
        item_ct1.barrier();
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1) {
            tmp += s_sum[lane_id + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }

    float mean = tmp / group_size;
    tmp = 0.0f;

    for (int j = start; j < end; j += block_size) {
        float xi = x[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }

    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        /*
        DPCT1118:2: SYCL group functions and algorithms must be encountered in
        converged control flow. You may need to adjust the code.
        */
        /*
        DPCT1065:55: Consider replacing sycl::nd_item::barrier() with
        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
        better performance if there is no access to global memory.
        */
        item_ct1.barrier();
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1) {
            tmp += s_sum[lane_id + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }

    float variance = tmp / group_size;
    float scale = sycl::rsqrt(variance + eps);
    for (int j = start; j < end; j += block_size) {
        dst[j] *= scale;
    }
}
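
// RMS norm kernel: like norm_f32, but only the sum of squares is reduced;
// each row is scaled by rsqrt(mean(x^2) + eps) without subtracting the mean.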
static void rms_norm_f32(const float* x, float* dst, const int ncols, const float eps,
                         const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);
    const int tid = item_ct1.get_local_id(2);
    const int nthreads = item_ct1.get_local_range(2);
    const int nwarps = nthreads / WARP_SIZE;
    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row * ncols + col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        /*
        DPCT1118:3: SYCL group functions and algorithms must be encountered in
        converged control flow. You may need to adjust the code.
        */
        item_ct1.barrier(sycl::access::fence_space::local_space);
        size_t nreduce = nwarps / WARP_SIZE;
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1) {
            tmp += s_sum[lane_id + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }

    const float mean = tmp / ncols;
    const float scale = sycl::rsqrt(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row * ncols + col] = scale * x[row * ncols + col];
    }
}
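
// Host-side launcher for norm_f32. For short rows (ncols < 1024) one warp per row
// suffices and no local scratch is needed; otherwise the kernel runs with the
// device's maximum work-group size and a local buffer of one sycl::float2 per warp
// for the cross-warp reduction.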
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
                          const int nrows, const float eps,
                          queue_ptr stream, int device) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                        norm_f32(x, dst, ncols, eps, item_ct1,
                                 nullptr, WARP_SIZE);
                    });
        });
    }
    else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
        const sycl::range<3> block_dims(1, 1, work_group_size);
        /*
        DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
                sycl::range<1>(work_group_size / WARP_SIZE), cgh);
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                        norm_f32(x, dst, ncols, eps, item_ct1,
                                 get_pointer(s_sum_acc_ct1), work_group_size);
                    });
        });
    }
}
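
// Host-side launcher for group_norm_f32: one work-group per group, with the same
// small/large split on group_size as norm_f32_sycl uses on ncols.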
static void group_norm_f32_sycl(const float* x, float* dst,
                                const int num_groups, const float eps, const int group_size,
                                const int ne_elements, queue_ptr stream, int device) {
    if (group_size < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            const float eps_ct4 = eps;
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                        group_norm_f32(
                            x, dst, group_size, ne_elements, eps_ct4, item_ct1,
                            nullptr, WARP_SIZE);
                    });
        });
    }
    else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
        const sycl::range<3> block_dims(1, 1, work_group_size);
        /*
        DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                         cgh);
            const float eps_ct4 = eps;
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                        group_norm_f32(x, dst, group_size, ne_elements,
                                       eps_ct4, item_ct1,
                                       get_pointer(s_sum_acc_ct1), work_group_size);
                    });
        });
    }
}
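
// Host-side launcher for rms_norm_f32, mirroring norm_f32_sycl but with a plain
// float scratch buffer since only a single partial sum is reduced per warp.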
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
                              const int nrows, const float eps,
                              queue_ptr stream, int device) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
    if (ncols < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
                                     nullptr, WARP_SIZE);
                    });
        });
    }
    else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
        const sycl::range<3> block_dims(1, 1, work_group_size);
        /*
        DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                         cgh);
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
                                     get_pointer(s_sum_acc_ct1), work_group_size);
                    });
        });
    }
}
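
// Backend op handler: reads eps from dst->op_params and layer-normalizes every
// row of src0 into dst.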
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
                       ggml_tensor* dst, const float* src0_dd,
                       const float* src1_dd, float* dst_dd,
                       const queue_ptr& main_stream) {
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);

    (void)src1;
    (void)dst;
    (void)src1_dd;
}
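
// Backend op handler for group norm: op_params[0] holds the number of groups and
// eps is read from op_params + 1; the group size is derived from the src0 shape.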
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
                             const ggml_tensor* src1, ggml_tensor* dst,
                             const float* src0_dd, const float* src1_dd,
                             float* dst_dd,
                             const queue_ptr& main_stream) {
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    int num_groups = dst->op_params[0];

    float eps;
    memcpy(&eps, dst->op_params + 1, sizeof(float));

    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);

    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);

    (void)src1;
    (void)dst;
    (void)src1_dd;
    GGML_UNUSED(ctx);
}
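
// Backend op handler: reads eps from dst->op_params and applies RMS normalization
// to every row of src0.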
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
                           const ggml_tensor* src1, ggml_tensor* dst,
                           const float* src0_dd, const float* src1_dd,
                           float* dst_dd,
                           const queue_ptr& main_stream) {
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);

    (void)src1;
    (void)dst;
    (void)src1_dd;
}