// cpy.wgsl (1.4 KB)

  1. enable f16;
  2. @group(0) @binding(0)
  3. var<storage, read_write> src: array<f32>;
  4. @group(0) @binding(1)
  5. var<storage, read_write> dst: array<f16>;
  6. struct Params {
  7. ne: u32, // total number of elements
  8. offset_src: u32, // in elements
  9. offset_dst: u32, // in elements
  10. // Strides (in elements) — may be permuted
  11. stride_src0: u32,
  12. stride_src1: u32,
  13. stride_src2: u32,
  14. stride_src3: u32,
  15. stride_dst0: u32,
  16. stride_dst1: u32,
  17. stride_dst2: u32,
  18. stride_dst3: u32,
  19. // Logical shape (same for both tensors)
  20. ne0: u32,
  21. ne1: u32,
  22. ne2: u32,
  23. ne3: u32,
  24. };
  25. @group(0) @binding(2)
  26. var<uniform> params: Params;
  27. override wg_size: u32;
  28. @compute @workgroup_size(wg_size)
  29. fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  30. if (gid.x >= params.ne) {
  31. return;
  32. }
  33. var i = gid.x;
  34. let i3 = i / (params.ne2 * params.ne1 * params.ne0);
  35. i = i % (params.ne2 * params.ne1 * params.ne0);
  36. let i2 = i / (params.ne1 * params.ne0);
  37. i = i % (params.ne1 * params.ne0);
  38. let i1 = i / params.ne0;
  39. let i0 = i % params.ne0;
  40. let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
  41. i2 * params.stride_src2 + i3 * params.stride_src3;
  42. let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 +
  43. i2 * params.stride_dst2 + i3 * params.stride_dst3;
  44. dst[params.offset_dst + dst_idx] = f16(src[params.offset_src + src_idx]);
  45. }