norm.cpp

#include "norm.hpp"
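
// norm_f32: layer-style normalization of one row of a contiguous f32 tensor.
// Each work-item accumulates partial sums of x and x^2 over its strided columns,
// warp_reduce_sum combines them within a sub-group, and when block_size spans
// more than one warp the per-warp results are combined again through the
// local-memory buffer s_sum (which may be nullptr for single-warp launches).
// The row is finally written as (x - mean) * rsqrt(var + eps).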
static void norm_f32(const float* x, float* dst, const int ncols, const float eps,
                     const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);
    const int tid = item_ct1.get_local_id(2);
    const int nthreads = item_ct1.get_local_range(2);
    const int nwarps = nthreads / WARP_SIZE;
    assert(nwarps % WARP_SIZE == 0);
    sycl::float2 mean_var = sycl::float2(0.f, 0.f);

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row * ncols + col];
        mean_var.x() += xi;
        mean_var.y() += xi * xi;
    }

    // sum up partial sums
    mean_var = warp_reduce_sum(mean_var, item_ct1);
    if (block_size > WARP_SIZE) {
        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = mean_var;
        }
        /*
        DPCT1118:0: SYCL group functions and algorithms must be encountered in
        converged control flow. You may need to adjust the code.
        */
        item_ct1.barrier(sycl::access::fence_space::local_space);
        mean_var = 0.f;
        int nreduce = nwarps / WARP_SIZE;
        for (size_t i = 0; i < nreduce; i += 1)
        {
            mean_var += s_sum[lane_id + i * WARP_SIZE];
        }
        mean_var = warp_reduce_sum(mean_var, item_ct1);
    }

    const float mean = mean_var.x() / ncols;
    const float var = mean_var.y() / ncols - mean * mean;
    const float inv_std = sycl::rsqrt(var + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std;
    }
}
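
// group_norm_f32: each work-group normalizes one group of group_size consecutive
// elements (clamped to ne_elements). Two reduction passes are used: the first
// computes the group mean, the second the variance of the mean-centred values
// already staged into dst; the group is then scaled by rsqrt(variance + eps).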
static void group_norm_f32(const float* x, float* dst, const int group_size, const int ne_elements, const float eps,
                           const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
    int start = item_ct1.get_group(2) * group_size;
    int end = start + group_size;
    const int nthreads = item_ct1.get_local_range(2);
    const int nwarps = nthreads / WARP_SIZE;
    assert(nwarps % WARP_SIZE == 0);
    start += item_ct1.get_local_id(2);
    int nreduce = nwarps / WARP_SIZE;

    if (end >= ne_elements) {
        end = ne_elements;
    }

    float tmp = 0.0f; // partial sum for thread in warp

    for (int j = start; j < end; j += block_size) {
        tmp += x[j];
    }

    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        /*
        DPCT1118:1: SYCL group functions and algorithms must be encountered in
        converged control flow. You may need to adjust the code.
        */
        /*
        DPCT1065:54: Consider replacing sycl::nd_item::barrier() with
        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
        better performance if there is no access to global memory.
        */
        item_ct1.barrier();
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1)
        {
            tmp += s_sum[lane_id + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }

    float mean = tmp / group_size;
    tmp = 0.0f;

    for (int j = start; j < end; j += block_size) {
        float xi = x[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }

    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        /*
        DPCT1118:2: SYCL group functions and algorithms must be encountered in
        converged control flow. You may need to adjust the code.
        */
        /*
        DPCT1065:55: Consider replacing sycl::nd_item::barrier() with
        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
        better performance if there is no access to global memory.
        */
        item_ct1.barrier();
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1)
        {
            tmp += s_sum[lane_id + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }

    float variance = tmp / group_size;
    float scale = sycl::rsqrt(variance + eps);
    for (int j = start; j < end; j += block_size) {
        dst[j] *= scale;
    }
}
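
// rms_norm_f32: RMS normalization of one row. Only the mean of x^2 is reduced
// (no mean subtraction), and the row is scaled by rsqrt(mean(x^2) + eps).
// The reduction structure mirrors norm_f32 above.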
static void rms_norm_f32(const float* x, float* dst, const int ncols, const float eps,
                         const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);
    const int tid = item_ct1.get_local_id(2);
    const int nthreads = item_ct1.get_local_range(2);
    const int nwarps = nthreads / WARP_SIZE;
    assert(nwarps % WARP_SIZE == 0);
    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row * ncols + col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        /*
        DPCT1118:3: SYCL group functions and algorithms must be encountered in
        converged control flow. You may need to adjust the code.
        */
        item_ct1.barrier(sycl::access::fence_space::local_space);
        int nreduce = nwarps / WARP_SIZE;
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1)
        {
            tmp += s_sum[lane_id + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }

    const float mean = tmp / ncols;
    const float scale = sycl::rsqrt(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[row * ncols + col] = scale * x[row * ncols + col];
    }
}
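
// norm_f32_sycl: launch helper, one work-group per row. Rows narrower than 1024
// columns are handled by a single sub-group of WARP_SIZE work-items with no local
// memory (s_sum == nullptr); wider rows use the device's maximum work-group size
// with one sycl::float2 of local memory per warp for the cross-warp reduction.
// reqd_sub_group_size pins the sub-group size to WARP_SIZE, presumably so that
// warp_reduce_sum operates on exactly one hardware sub-group.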
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
                          const int nrows, const float eps,
                          queue_ptr stream, int device) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                        norm_f32(x, dst, ncols, eps, item_ct1,
                                 nullptr, WARP_SIZE);
                    });
        });
    }
    else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        const sycl::range<3> block_dims(1, 1, work_group_size);
        /*
        DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
                sycl::range<1>(work_group_size / WARP_SIZE), cgh);
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                        norm_f32(x, dst, ncols, eps, item_ct1,
                                 get_pointer(s_sum_acc_ct1), work_group_size);
                    });
        });
    }
}
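
// group_norm_f32_sycl: same dispatch pattern as norm_f32_sycl, but with one
// work-group per group. Note that eps is hard-coded to 1e-6f here instead of
// being taken from the operator parameters.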
static void group_norm_f32_sycl(const float* x, float* dst,
                                const int num_groups, const int group_size,
                                const int ne_elements, queue_ptr stream, int device) {
    static const float eps = 1e-6f;
    if (group_size < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            const float eps_ct4 = eps;
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                        group_norm_f32(
                            x, dst, group_size, ne_elements, eps_ct4, item_ct1,
                            nullptr, WARP_SIZE);
                    });
        });
    }
    else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        const sycl::range<3> block_dims(1, 1, work_group_size);
        /*
        DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                         cgh);
            const float eps_ct4 = eps;
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                        group_norm_f32(x, dst, group_size, ne_elements,
                                       eps_ct4, item_ct1,
                                       get_pointer(s_sum_acc_ct1), work_group_size);
                    });
        });
    }
}
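
// rms_norm_f32_sycl: identical dispatch to norm_f32_sycl, with a plain float
// (rather than sycl::float2) local accumulator per warp.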
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
                              const int nrows, const float eps,
                              queue_ptr stream, int device) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
    if (ncols < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
                                     nullptr, WARP_SIZE);
                    });
        });
    }
    else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        const sycl::range<3> block_dims(1, 1, work_group_size);
        /*
        DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                         cgh);
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
                                     get_pointer(s_sum_acc_ct1), work_group_size);
                    });
        });
    }
}
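
// ggml backend entry points. src0_dd holds the f32 input, dst_dd the f32 output;
// eps is read from dst->op_params. src1/src1_dd are unused by these ops.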
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
                       ggml_tensor* dst, const float* src0_dd,
                       const float* src1_dd, float* dst_dd,
                       const queue_ptr& main_stream) {

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);

    (void)src1;
    (void)dst;
    (void)src1_dd;
}
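
// For group norm, num_groups comes from dst->op_params[0] and group_size is
// ne0 * ne1 * ceil(ne2 / num_groups), i.e. each group spans whole ne2 "channels".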
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
                             const ggml_tensor* src1, ggml_tensor* dst,
                             const float* src0_dd, const float* src1_dd,
                             float* dst_dd,
                             const queue_ptr& main_stream) {

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    int num_groups = dst->op_params[0];
    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);

    (void)src1;
    (void)dst;
    (void)src1_dd;
}

void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
                           const ggml_tensor* src1, ggml_tensor* dst,
                           const float* src0_dd, const float* src1_dd,
                           float* dst_dd,
                           const queue_ptr& main_stream) {

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);

    (void)src1;
    (void)dst;
    (void)src1_dd;
}