Browse Source

ggml webgpu: support for backend sampling (#18880)

* ggml webgpu: add SOFTPLUS unary operator

Implements SOFTPLUS (log(1 + exp(x))) with f16/f32 support. Uses f32
precision for intermediate calculations to prevent f16 overflow.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support
* Follow Vulkan backend numerical stability pattern

* ggml webgpu: add EXPM1 unary operator

Implements EXPM1 (exp(x) - 1) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add FLOOR unary operator

Implements FLOOR (rounds down to nearest integer) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add CEIL unary operator

Implements CEIL (rounds up to nearest integer) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add ROUND unary operator

Implements ROUND (rounds to nearest integer) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add TRUNC unary operator

Implements TRUNC (truncates towards zero) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* docs : update WebGPU support for unary operators (FLOOR, CEIL, ROUND, TRUNC, EXPM1, SOFTPLUS)

* Updates to webgpu get_memory

* Add argmax

* Add argmax,cumsum,sum,sum_rows

* Add necessary CPY/GET_ROWS operators

* Support for argsort using multi-pass strategy

* Update set_rows for i32 indices, move to pre-wgsl

* Port unary operators to pre-wgsl and support FILL

* Implement PAD

* Add support for top-k

* clean up, scope pipeline init mutex

* fix newline

* Add support for log

* Update LOG for better precision, and ops doc

---------

Co-authored-by: Abhijit Ramesh <abhijitramesh2k@gmail.com>
Reese Levine 1 week ago
parent
commit
a89002f07b

+ 17 - 16
docs/ops.md

@@ -20,10 +20,10 @@ Legend:
 |                             ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                           ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                           ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
-|                           ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ |  | ❌ | ❌ |
-|                          ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ |  | ❌ | ❌ |
+|                           ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ |  | ❌ | ❌ |
+|                          ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ |  | ❌ | ❌ |
 |                             CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
-|                            CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 |  | ❌ | ❌ |
+|                            CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 |  | ❌ | ❌ |
 |                           CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                             CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
 |                          CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
@@ -36,17 +36,17 @@ Legend:
 |                              CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
 |               CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 |          CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
-|                           CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ |  | ❌ | ❌ |
+|                           CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ |  | ❌ | ❌ |
 |                             DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                    DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                              DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
 |                              DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                              ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
 |                              EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
-|                            EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ |  | ❌ | ❌ |
-|                             FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ |  | ❌ | ❌ |
-|                   FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ | 🟡 |  | ❌ | ❌ |
-|                            FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 |  | ❌ | ❌ |
+|                            EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ |  | ❌ | ❌ |
+|                             FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ |  | ❌ | ❌ |
+|                   FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
+|                            FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 |  | ❌ | ❌ |
 |                GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
 |                            GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
 |                        GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
@@ -63,7 +63,7 @@ Legend:
 |                        IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
 |                          L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                       LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
-|                              LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ✅ |  | ❌ | ❌ |
+|                              LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ✅ |  | ❌ | ❌ |
 |                             MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                              MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
 |                          MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
@@ -73,8 +73,9 @@ Legend:
 |                   OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
 |                     OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
 |                         OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
-|                              PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ |  | ❌ | ❌ |
+|                              PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ |  | ❌ | ❌ |
 |                   PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
+|                          POOL_1D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                          POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                            REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
 |                             RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
@@ -85,7 +86,7 @@ Legend:
 |                             ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                             ROPE | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
 |                        ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
-|                            ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 |  | ❌ | ❌ |
+|                            ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 |  | ❌ | ❌ |
 |                        RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                        RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                            SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
@@ -96,7 +97,7 @@ Legend:
 |                             SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
 |                        SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
 |                              SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
-|                         SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 |  | ❌ | ❌ |
+|                         SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 |  | ❌ | ❌ |
 |                         SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
 |                    SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
 |                        SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
@@ -106,14 +107,14 @@ Legend:
 |                         SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
 |                             STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
 |                              SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
-|                              SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 |  | ❌ | ❌ |
-|                         SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ |  | ❌ | ❌ |
+|                              SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
+|                         SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ |  | ❌ | ❌ |
 |                           SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
 |                       SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
 |                             TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
 |               TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
-|                            TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 |  | ❌ | ❌ |
+|                            TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 |  | ❌ | ❌ |
 |                              TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
-|                            TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 |  | ❌ | ❌ |
+|                            TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 |  | ❌ | ❌ |
 |                          UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
 |                            XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |

File diff suppressed because it is too large
+ 599 - 501
docs/ops/WebGPU.csv


+ 335 - 36
ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

@@ -9,12 +9,28 @@
 
 #define GGML_WEBGPU_F16_SIZE_BYTES                   2
 #define GGML_WEBGPU_F32_SIZE_BYTES                   4
+#define GGML_WEBGPU_I32_SIZE_BYTES                   4
 #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
 #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE     128u
 // Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
 #define GGML_WEBGPU_KV_SEQ_PAD                       256u
 
-struct ggml_webgpu_flash_attn_shader_lib_context {
+#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u
+
+struct ggml_webgpu_processed_shader {
+    std::string wgsl;
+    std::string variant;
+    void *      decisions;
+};
+
+// Same hash combine function as in boost
+template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
+    seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+}
+
+/** FlashAttention */
+
+struct ggml_webgpu_flash_attn_pipeline_key {
     ggml_type kv_type;
     uint32_t  head_dim_qk;
     uint32_t  head_dim_v;
@@ -22,11 +38,35 @@ struct ggml_webgpu_flash_attn_shader_lib_context {
     bool      has_mask;
     bool      has_sinks;
     bool      uses_logit_softcap;
-    uint32_t  sg_mat_m;
-    uint32_t  sg_mat_n;
-    uint32_t  sg_mat_k;
-    size_t    wg_mem_limit_bytes;
-    uint32_t  max_subgroup_size;
+
+    bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
+        return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
+               kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
+               uses_logit_softcap == other.uses_logit_softcap;
+    }
+};
+
+struct ggml_webgpu_flash_attn_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.kv_type);
+        ggml_webgpu_hash_combine(seed, key.head_dim_qk);
+        ggml_webgpu_hash_combine(seed, key.head_dim_v);
+        ggml_webgpu_hash_combine(seed, key.kv_direct);
+        ggml_webgpu_hash_combine(seed, key.has_mask);
+        ggml_webgpu_hash_combine(seed, key.has_sinks);
+        ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
+        return seed;
+    }
+};
+
+struct ggml_webgpu_flash_attn_shader_lib_context {
+    ggml_webgpu_flash_attn_pipeline_key key;
+    uint32_t                            sg_mat_m;
+    uint32_t                            sg_mat_n;
+    uint32_t                            sg_mat_k;
+    size_t                              wg_mem_limit_bytes;
+    uint32_t                            max_subgroup_size;
 };
 
 struct ggml_webgpu_flash_attn_shader_decisions {
@@ -35,12 +75,6 @@ struct ggml_webgpu_flash_attn_shader_decisions {
     uint32_t wg_size = 0;
 };
 
-struct ggml_webgpu_processed_shader {
-    std::string                             wgsl;
-    std::string                             variant;
-    ggml_webgpu_flash_attn_shader_decisions decisions;
-};
-
 // This is exposed because it's necessary in supports_op
 inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
                                                   uint32_t kv_tile,
@@ -66,15 +100,16 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
 }
 
 static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
-    const size_t limit_bytes  = context.wg_mem_limit_bytes;
-    const size_t q_tile       = context.sg_mat_m;
-    const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
-                                2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
+    const size_t limit_bytes = context.wg_mem_limit_bytes;
+    const size_t q_tile      = context.sg_mat_m;
+    const size_t base_q_bytes =
+        (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
+        2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
     size_t bytes_per_kv = 0;
-    if (!context.kv_direct) {
-        bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v);
+    if (!context.key.kv_direct) {
+        bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
     }
-    if (context.has_mask) {
+    if (context.key.has_mask) {
         bytes_per_kv += q_tile;
     }
     bytes_per_kv += q_tile;
@@ -90,7 +125,7 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
     std::vector<std::string> defines;
     std::string              variant = "flash_attn";
 
-    switch (context.kv_type) {
+    switch (context.key.kv_type) {
         case GGML_TYPE_F32:
             defines.push_back("KV_F32");
             break;
@@ -106,32 +141,31 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
         default:
             GGML_ABORT("Unsupported KV type for flash attention shader");
     }
-    variant += std::string("_") + ggml_type_name(context.kv_type);
+    variant += std::string("_") + ggml_type_name(context.key.kv_type);
 
-    if (context.has_mask) {
+    if (context.key.has_mask) {
         defines.push_back("MASK");
         variant += "_mask";
     }
-    if (context.has_sinks) {
+    if (context.key.has_sinks) {
         defines.push_back("SINKS");
         variant += "_sinks";
     }
-    if (context.uses_logit_softcap) {
+    if (context.key.uses_logit_softcap) {
         defines.push_back("LOGIT_SOFTCAP");
         variant += "_lgsc";
     }
 
-    if (context.kv_direct) {
+    if (context.key.kv_direct) {
         defines.push_back("KV_DIRECT");
         variant += "_kvdirect";
     }
 
-    defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk));
-    variant += std::string("_hsqk") + std::to_string(context.head_dim_qk);
-
-    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v));
-    variant += std::string("_hsv") + std::to_string(context.head_dim_v);
+    defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
+    variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
 
+    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
+    variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
     // For now these are not part of the variant name
     defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
     defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
@@ -141,7 +175,7 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
     uint32_t q_tile  = context.sg_mat_m;
     uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
                                 context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
-    if (context.kv_direct) {
+    if (context.key.kv_direct) {
         GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
         // Avoids having to use bounds-checks and decreasing performance for direct KV loads
         while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
@@ -158,11 +192,276 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
     defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
 
     ggml_webgpu_processed_shader result;
-    result.wgsl              = preprocessor.preprocess(shader_src, defines);
-    result.variant           = variant;
-    result.decisions.q_tile  = q_tile;
-    result.decisions.kv_tile = kv_tile;
-    result.decisions.wg_size = wg_size;
+    result.wgsl                                         = preprocessor.preprocess(shader_src, defines);
+    result.variant                                      = variant;
+    ggml_webgpu_flash_attn_shader_decisions * decisions = new ggml_webgpu_flash_attn_shader_decisions();
+    decisions->q_tile                                   = q_tile;
+    decisions->kv_tile                                  = kv_tile;
+    decisions->wg_size                                  = wg_size;
+    result.decisions                                    = decisions;
+    return result;
+}
+
+/** Generic **/
+
+struct ggml_webgpu_generic_shader_lib_context {
+    int      vec4;
+    uint32_t max_wg_size;
+};
+
+struct ggml_webgpu_generic_shader_decisions {
+    uint32_t wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader(
+    pre_wgsl::Preprocessor &                       preprocessor,
+    const char *                                   shader_src,
+    const ggml_webgpu_generic_shader_lib_context & context,
+    const std::string &                            base_variant) {
+    std::vector<std::string> defines;
+    std::string              variant = base_variant;
+
+    if (context.vec4) {
+        defines.push_back("VEC4");
+        variant += "_vec";
+    }
+
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+    ggml_webgpu_processed_shader result;
+    result.wgsl    = preprocessor.preprocess(shader_src, defines);
+    result.variant = variant;
+    return result;
+}
+
+/** Pad **/
+
+struct ggml_webgpu_pad_pipeline_key {
+    bool circular;
+
+    bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
+};
+
+struct ggml_webgpu_pad_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.circular);
+        return seed;
+    }
+};
+
+struct ggml_webgpu_pad_shader_lib_context {
+    ggml_webgpu_pad_pipeline_key key;
+    uint32_t                     max_wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader(
+    pre_wgsl::Preprocessor &                   preprocessor,
+    const char *                               shader_src,
+    const ggml_webgpu_pad_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string              variant = "pad";
+
+    if (context.key.circular) {
+        defines.push_back("CIRCULAR");
+        variant += "_circular";
+    }
+
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+    ggml_webgpu_processed_shader result;
+    result.wgsl                                      = preprocessor.preprocess(shader_src, defines);
+    result.variant                                   = variant;
+    ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
+    decisions->wg_size                               = context.max_wg_size;
+    result.decisions                                 = decisions;
+    return result;
+}
+
+/** Argsort **/
+
+struct ggml_webgpu_argsort_shader_lib_context {
+    uint32_t max_wg_size;
+    size_t   wg_mem_limit_bytes;
+    int32_t  order;
+};
+
+struct ggml_webgpu_argsort_shader_decisions {
+    uint32_t wg_size = 0;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader(
+    pre_wgsl::Preprocessor &                       preprocessor,
+    const char *                                   shader_src,
+    const ggml_webgpu_argsort_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string              variant = "argsort";
+    defines.push_back(std::string("ORDER=") + std::to_string(context.order));
+    variant += std::string("_order") + std::to_string(context.order);
+    uint32_t wg_size = 1;
+    while (wg_size * 2 <= context.max_wg_size &&
+           wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
+        wg_size *= 2;
+    }
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+    ggml_webgpu_processed_shader result;
+    result.wgsl                                      = preprocessor.preprocess(shader_src, defines);
+    result.variant                                   = variant;
+    ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions();
+    decisions->wg_size                               = wg_size;
+    result.decisions                                 = decisions;
+    return result;
+}
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader(
+    pre_wgsl::Preprocessor &                       preprocessor,
+    const char *                                   shader_src,
+    const ggml_webgpu_argsort_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string              variant = "argsort_merge";
+    defines.push_back(std::string("ORDER=") + std::to_string(context.order));
+    variant += std::string("_order") + std::to_string(context.order);
+    uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+    ggml_webgpu_processed_shader result;
+    result.wgsl                                      = preprocessor.preprocess(shader_src, defines);
+    result.variant                                   = variant;
+    ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions();
+    decisions->wg_size                               = wg_size;
+    result.decisions                                 = decisions;
+    return result;
+}
+
+/** Set Rows **/
+
+struct ggml_webgpu_set_rows_pipeline_key {
+    int dst_type;
+    int vec4;
+    int i64_idx;
+
+    bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
+        return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
+    }
+};
+
+struct ggml_webgpu_set_rows_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.dst_type);
+        ggml_webgpu_hash_combine(seed, key.vec4);
+        ggml_webgpu_hash_combine(seed, key.i64_idx);
+        return seed;
+    }
+};
+
+struct ggml_webgpu_set_rows_shader_lib_context {
+    ggml_webgpu_set_rows_pipeline_key key;
+    uint32_t                          max_wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader(
+    pre_wgsl::Preprocessor &                        preprocessor,
+    const char *                                    shader_src,
+    const ggml_webgpu_set_rows_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string              variant = "set_rows";
+
+    switch (context.key.dst_type) {
+        case GGML_TYPE_F32:
+            defines.push_back("DST_F32");
+            variant += "_dstf32";
+            break;
+        case GGML_TYPE_F16:
+            defines.push_back("DST_F16");
+            variant += "_dstf16";
+            break;
+        default:
+            GGML_ABORT("Unsupported dst type for set_rows shader");
+    }
+
+    if (context.key.vec4) {
+        defines.push_back("VEC4");
+        variant += "_vec";
+    }
+    if (context.key.i64_idx) {
+        defines.push_back("I64_IDX");
+        variant += "_i64idx";
+    }
+
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+    ggml_webgpu_processed_shader result;
+    result.wgsl                                      = preprocessor.preprocess(shader_src, defines);
+    result.variant                                   = variant;
+    ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
+    decisions->wg_size                               = context.max_wg_size;
+    result.decisions                                 = decisions;
+    return result;
+}
+
+struct ggml_webgpu_unary_pipeline_key {
+    int  type;
+    int  op;
+    bool is_unary;  // many unary operators fall under the GGML_OP_UNARY umbrella
+    bool inplace;
+
+    bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
+        return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
+    }
+};
+
+struct ggml_webgpu_unary_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        ggml_webgpu_hash_combine(seed, key.op);
+        ggml_webgpu_hash_combine(seed, key.is_unary);
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        return seed;
+    }
+};
+
+struct ggml_webgpu_unary_shader_lib_context {
+    ggml_webgpu_unary_pipeline_key key;
+    uint32_t                       max_wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
+    pre_wgsl::Preprocessor &                     preprocessor,
+    const char *                                 shader_src,
+    const ggml_webgpu_unary_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string              variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) :
+                                                              ggml_op_name((ggml_op) context.key.op);
+    // Operation-specific behavior
+    defines.push_back(variant);
+
+    switch (context.key.type) {
+        case GGML_TYPE_F32:
+            defines.push_back("TYPE_F32");
+            variant += "_f32";
+            break;
+        case GGML_TYPE_F16:
+            defines.push_back("TYPE_F16");
+            variant += "_f16";
+            break;
+        default:
+            GGML_ABORT("Unsupported type for unary shader");
+    }
+
+    if (context.key.inplace) {
+        defines.push_back("INPLACE");
+        variant += "_inplace";
+    }
+
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+    ggml_webgpu_processed_shader result;
+    result.wgsl                                      = preprocessor.preprocess(shader_src, defines);
+    result.variant                                   = variant;
+    ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
+    decisions->wg_size                               = context.max_wg_size;
+    result.decisions                                 = decisions;
     return result;
 }
 

File diff suppressed because it is too large
+ 634 - 166
ggml/src/ggml-webgpu/ggml-webgpu.cpp


+ 72 - 0
ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl

@@ -0,0 +1,72 @@
+@group(0) @binding(0)
+#ifdef VEC4
+var<storage, read_write> src: array<vec4<f32>>;
+#define VEC_SIZE 4
+#else
+var<storage, read_write> src: array<f32>;
+#define VEC_SIZE 1
+#endif
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<i32>;
+
+struct Params {
+    offset_src: u32, // in elements
+    offset_dst: u32, // in elements
+    ne0: u32,
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+const FLOAT_MIN: f32 = -1.0e9;
+
+struct Pair {
+    value: f32,
+    index: i32
+};
+
+var<workgroup> shared_max: array<Pair, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+    let row_idx = params.offset_src + wid.x * params.ne0;
+    var local_pair = Pair(FLOAT_MIN, -1);
+#ifdef VEC4
+    for (var col = lid.x; col < params.ne0/VEC_SIZE; col += WG_SIZE) {
+        let vec_val = src[row_idx / VEC_SIZE + col];
+        for (var v = 0u; v < VEC_SIZE; v++) {
+            let val = vec_val[v];
+            if (val >= local_pair.value) {
+                local_pair = Pair(val, i32(col * VEC_SIZE + v));
+            }
+        }
+    }
+#else
+    for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
+        if (src[row_idx + col] >= local_pair.value) {
+            local_pair = Pair(src[row_idx + col], i32(col));
+        }
+    }
+#endif
+    shared_max[lid.x] = local_pair;
+    workgroupBarrier();
+    var offset: u32 = WG_SIZE >> 1;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            let a = shared_max[lid.x];
+            let b = shared_max[lid.x + offset];
+            if (b.value > a.value) {
+                shared_max[lid.x] = b;
+            } else if (b.value == a.value && b.index > a.index) {
+                shared_max[lid.x] = b;
+            }
+        }
+        workgroupBarrier();
+        offset >>= 1;
+    }
+    if (lid.x == 0u) {
+        dst[params.offset_dst + wid.x] = shared_max[0].index;
+    }
+}

+ 106 - 0
ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl

@@ -0,0 +1,106 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<i32>;
+
+struct Params {
+    offset_src: u32, // in elements
+    offset_dst: u32, // in elements
+
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // src/dst dimensions
+    src_ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    ne0: u32,
+    top_k: u32,
+
+    npr: u32,   // tiles per row
+    nrows: u32
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+var<workgroup> shmem_idx: array<u32, WG_SIZE>;
+
+#if ORDER == 0
+#define EXTREME_VALUE 1e30
+#define SWAP_COMPARE_UP >
+#define SWAP_COMPARE_DOWN <
+#else
+#define EXTREME_VALUE -1e30
+#define SWAP_COMPARE_UP <
+#define SWAP_COMPARE_DOWN >
+#endif
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(num_workgroups) num_wg: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+    let linear = wid.x + wid.y * num_wg.x;
+    // guard against overprovisioned workgroups
+    if (linear >= params.npr * params.nrows) {
+        return;
+    }
+    let tile = linear % params.npr;
+    var row = linear / params.npr;
+    let i3 = row / (params.ne2 * params.ne1);
+    row = row % (params.ne2 * params.ne1);
+    let i2 = row / params.ne1;
+    let i1 = row % params.ne1;
+
+    let row_base = params.offset_src +
+        i1 * params.stride_src1 +
+        i2 * params.stride_src2 +
+        i3 * params.stride_src3;
+
+    let tile_base = tile * WG_SIZE;
+    let idx = tile_base + lid.x;
+    shmem_idx[lid.x] = select(params.src_ne0, idx, idx < params.src_ne0);
+    workgroupBarrier();
+
+    var k = 2u;
+    while (k <= WG_SIZE) {
+        var j = k >> 1;
+        while (j > 0) {
+            let ixj = lid.x ^ j;
+            if (ixj > lid.x) {
+                let dir_up = (lid.x & k) == 0;
+                let a_idx = shmem_idx[lid.x];
+                let b_idx = shmem_idx[ixj];
+                let a_val = select(EXTREME_VALUE, src[row_base + a_idx], a_idx < params.src_ne0);
+                let b_val = select(EXTREME_VALUE, src[row_base + b_idx], b_idx < params.src_ne0);
+                let should_swap = select(
+                    (a_val SWAP_COMPARE_DOWN b_val),
+                    (a_val SWAP_COMPARE_UP b_val),
+                    dir_up);
+                if (should_swap) {
+                    shmem_idx[lid.x] = b_idx;
+                    shmem_idx[ixj] = a_idx;
+                }
+            }
+            workgroupBarrier();
+            j >>= 1;
+        }
+        k <<= 1;
+    }
+
+    let out_idx = tile * params.top_k + lid.x;
+    if (out_idx < params.ne0 && lid.x < params.top_k) {
+        let row_dst = params.offset_dst +
+            i1 * params.stride_dst1 +
+            i2 * params.stride_dst2 +
+            i3 * params.stride_dst3;
+        dst[row_dst + out_idx] = i32(shmem_idx[lid.x]);
+    }
+}

+ 134 - 0
ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl

@@ -0,0 +1,134 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> idx_in: array<i32>;
+
+@group(0) @binding(2)
+var<storage, read_write> idx_out: array<i32>;
+
+struct Params {
+    offset_src: u32, // in elements
+    offset_in: u32,  // in elements
+    offset_out: u32, // in elements
+
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_idx1: u32,
+    stride_idx2: u32,
+    stride_idx3: u32,
+
+    stride_out1: u32,
+    stride_out2: u32,
+    stride_out3: u32,
+
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    top_k: u32,
+
+    len: u32,
+    nm: u32,
+    nrows: u32
+};
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+fn take_left(a_idx: i32, b_idx: i32, row_base: u32) -> bool {
+    let a_val = src[row_base + u32(a_idx)];
+    let b_val = src[row_base + u32(b_idx)];
+#if ORDER == 0
+    return a_val <= b_val;
+#else
+    return a_val >= b_val;
+#endif
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(num_workgroups) num_wg: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+    let linear = wid.x + wid.y * num_wg.x;
+    // guard against overprovisioned workgroups
+    if (linear >= params.nm * params.nrows) {
+        return;
+    }
+
+    let start = (linear % params.nm) * params.len * 2;
+    let len0 = min(params.len, params.ne0 - start);
+    let rem1 = select(0, params.ne0 - (start + params.len), params.ne0 > (start + params.len));
+    let len1 = min(params.len, rem1);
+    let total = len0 + len1;
+    let chunk = (total + WG_SIZE - 1u) / WG_SIZE;
+    let k0 = lid.x * chunk;
+    let k1 = min(min(k0 + chunk, total), params.top_k);
+    // guard against overprovisioned threads
+    if (k0 >= params.top_k || k0 >= total) {
+        return;
+    }
+
+    var row = linear / params.nm;
+    let i3 = row / (params.ne2 * params.ne1);
+    row = row % (params.ne2 * params.ne1);
+    let i2 = row / params.ne1;
+    let i1 = row % params.ne1;
+
+    let row_src = params.offset_src +
+        i1 * params.stride_src1 +
+        i2 * params.stride_src2 +
+        i3 * params.stride_src3;
+
+    let row_in = params.offset_in +
+        i1 * params.stride_idx1 +
+        i2 * params.stride_idx2 +
+        i3 * params.stride_idx3;
+
+    let row_out = params.offset_out +
+        i1 * params.stride_out1 +
+        i2 * params.stride_out2 +
+        i3 * params.stride_out3;
+
+
+    var low: u32 = select(0, k0 - len1, k0 > len1);
+    var high: u32 = min(k0, len0);
+
+    while (low < high) {
+        let mid = (low + high) >> 1;
+        let idx0 = idx_in[row_in + start + mid];
+        let idx1 = idx_in[row_in + start + params.len + (k0 - mid - 1)];
+        if (take_left(idx0, idx1, row_src)) {
+            low = mid + 1;
+        } else {
+            high = mid;
+        }
+    }
+
+    var i = low;
+    var j = k0 - i;
+    var k = k0;
+    while (k < k1) {
+        var take_l = false;
+        if (i >= len0) {
+            take_l = false;
+        } else if (j >= len1) {
+            take_l = true;
+        } else {
+            let idx0 = idx_in[row_in + start + i];
+            let idx1 = idx_in[row_in + start + params.len + j];
+            take_l = take_left(idx0, idx1, row_src);
+        }
+
+        let out_idx = select(
+            idx_in[row_in + start + params.len + j],
+            idx_in[row_in + start + i],
+            take_l);
+        idx_out[row_out + start + k] = out_idx;
+        i = select(i, i + 1, take_l);
+        j = select(j + 1, j, take_l);
+        k += 1;
+    }
+}

+ 6 - 0
ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl

@@ -7,6 +7,12 @@
       "DST_TYPE": "f32"
     }
   },
+  {
+    "REPLS": {
+      "SRC_TYPE": "f32",
+      "DST_TYPE": "i32"
+    }
+  },
   {
     "REPLS": {
       "SRC_TYPE": "f32",

+ 66 - 0
ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl

@@ -0,0 +1,66 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+struct Params {
+    offset_src: u32, // in elements
+    offset_dst: u32, // in elements
+    ne0: u32,
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+var<workgroup> shared_sum: array<f32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+    let row_idx = params.offset_src + wid.x * params.ne0;
+    let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
+    var local_sum: f32 = 0.0;
+    for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col ++) {
+        local_sum += src[row_idx + col];
+    }
+    shared_sum[lid.x] = local_sum;
+    workgroupBarrier();
+
+    // upsweep
+    var offset = 1u;
+    while (offset < WG_SIZE) {
+        let idx = (lid.x + 1) * offset * 2 - 1;
+        if (idx < WG_SIZE) {
+            shared_sum[idx] = shared_sum[idx] + shared_sum[idx - offset];
+        }
+        workgroupBarrier();
+        offset <<= 1;
+    }
+
+    // set last to 0 for exclusive sum
+    if (lid.x == 0) {
+        shared_sum[WG_SIZE - 1] = 0.0;
+    }
+    workgroupBarrier();
+
+    // downsweep
+    offset = WG_SIZE >> 1;
+    while (offset > 0) {
+        let idx = (lid.x + 1) * offset * 2 - 1;
+        if (idx < WG_SIZE) {
+            let t = shared_sum[idx - offset];
+            shared_sum[idx - offset] = shared_sum[idx];
+            shared_sum[idx] = shared_sum[idx] + t;
+        }
+        workgroupBarrier();
+        offset = offset >> 1;
+    }
+
+    // shared_sum[lid] is exclusive prefix sum up to this thread.
+    var running_sum = shared_sum[lid.x];
+    for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col ++) {
+        running_sum += src[row_idx + col];
+        dst[params.offset_dst + wid.x * params.ne0 + col] = running_sum;
+    }
+}

+ 86 - 0
ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl

@@ -0,0 +1,86 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+struct Params {
+    ne: u32,            // total number of elements
+    offset_src: u32,    // in elements
+    offset_dst: u32,    // in elements
+
+    // Strides (in elements)
+    stride_src0: u32,
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    // Logical shapes
+    src_ne0: u32,
+    src_ne1: u32,
+    src_ne2: u32,
+    src_ne3: u32,
+
+    dst_ne0: u32,
+    dst_ne1: u32,
+    dst_ne2: u32,
+    dst_ne3: u32,
+
+    // Pad sizes (in elements)
+    lp0: u32,
+    rp0: u32,
+    lp1: u32,
+    rp1: u32,
+    lp2: u32,
+    rp2: u32,
+    lp3: u32,
+    rp3: u32,
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+fn wrap_around(idx: i32, n: u32) -> u32 {
+    return u32(idx + i32(n)) % n;
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+        return;
+    }
+
+    var i = gid.x;
+    let dst_plane = params.dst_ne2 * params.dst_ne1 * params.dst_ne0;
+    let i3 = i / dst_plane;
+    i = i % dst_plane;
+    let i2 = i / (params.dst_ne1 * params.dst_ne0);
+    i = i % (params.dst_ne1 * params.dst_ne0);
+    let i1 = i / params.dst_ne0;
+    let i0 = i % params.dst_ne0;
+
+    var value: f32 = 0.0;
+
+#ifdef CIRCULAR
+    let ci0 = wrap_around(i32(i0) - i32(params.lp0), params.src_ne0);
+    let ci1 = wrap_around(i32(i1) - i32(params.lp1), params.src_ne1);
+    let ci2 = wrap_around(i32(i2) - i32(params.lp2), params.src_ne2);
+    let ci3 = wrap_around(i32(i3) - i32(params.lp3), params.src_ne3);
+    let circular_src_idx = ci0 * params.stride_src0 + ci1 * params.stride_src1 +
+                           ci2 * params.stride_src2 + ci3 * params.stride_src3;
+    value = src[params.offset_src + circular_src_idx];
+#else
+    let is_src =
+        (i0 >= params.lp0 && i0 < params.dst_ne0 - params.rp0) &&
+        (i1 >= params.lp1 && i1 < params.dst_ne1 - params.rp1) &&
+        (i2 >= params.lp2 && i2 < params.dst_ne2 - params.rp2) &&
+        (i3 >= params.lp3 && i3 < params.dst_ne3 - params.rp3);
+    if (is_src) {
+        let src_idx = (i0 - params.lp0) * params.stride_src0 + (i1 - params.lp1) * params.stride_src1 +
+                      (i2 - params.lp2) * params.stride_src2 + (i3 - params.lp3) * params.stride_src3;
+        value = src[params.offset_src + src_idx];
+    }
+#endif
+
+    dst[params.offset_dst + gid.x] = value;
+}

+ 35 - 38
ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl → ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl

@@ -1,41 +1,37 @@
-#define(VARIANTS)
-
-[
-  {
-    "SHADER_SUFFIX": "f16_vec",
-    "REPLS": {
-      "TYPE" : "vec4<f32>",
-      "DST_TYPE": "vec4<f16>",
-      "VEC_SIZE": 4
-    }
-  },
-  {
-    "SHADER_SUFFIX": "f16",
-    "REPLS": {
-      "TYPE" : "f32",
-      "DST_TYPE": "f16",
-      "VEC_SIZE": 1
-    }
-  }
-]
-
-#end(VARIANTS)
-
-#define(SHADER)
-
 enable f16;
 
+#ifdef DST_F32
+#define DST_INNER_TYPE f32
+#else
+#define DST_INNER_TYPE f16
+#endif
+
+#ifdef VEC4
+#define SRC_TYPE vec4<f32>
+#define DST_TYPE vec4<DST_INNER_TYPE>
+#define VEC_SIZE 4
+#else
+#define SRC_TYPE f32
+#define DST_TYPE DST_INNER_TYPE
+#define VEC_SIZE 1
+#endif
+
 @group(0) @binding(0)
-var<storage, read_write> src: array<{{TYPE}}>;
+var<storage, read_write> src: array<SRC_TYPE>;
 
 @group(0) @binding(1)
 var<storage, read_write> idx: array<u32>;
 
 @group(0) @binding(2)
-var<storage, read_write> dst: array<{{DST_TYPE}}>;
+var<storage, read_write> dst: array<DST_TYPE>;
 
+#ifdef I64_IDX
 @group(0) @binding(3)
 var<storage, read_write> error: atomic<u32>;
+#define PARAMS_BINDING 4
+#else
+#define PARAMS_BINDING 3
+#endif
 
 struct Params {
     offset_src: u32, // in elements
@@ -66,18 +62,17 @@ struct Params {
     idx2: u32,
 };
 
-@group(0) @binding(4)
+@group(0) @binding(PARAMS_BINDING)
 var<uniform> params: Params;
 
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
+@compute @workgroup_size(WG_SIZE)
 fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / {{VEC_SIZE}}) {
+    if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / VEC_SIZE) {
         return;
     }
 
     // getting the row from gid
-    let elems_per_row = params.ne0 / {{VEC_SIZE}};
+    let elems_per_row = params.ne0 / VEC_SIZE;
     var i = gid.x / elems_per_row;
 
     let i_src3 = i / (params.ne2 * params.n_rows);
@@ -90,9 +85,10 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     let i_idx1 = i_src2 % params.idx1;
     let i_idx0 = i_src1;
 
+#ifdef I64_IDX
     let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
 
-    let idx_high_val = idx[idx_high];
+    let idx_val = idx[idx_high];
     let idx_low_val = idx[idx_high + 1];
 
     if (idx_low_val != 0) {
@@ -100,13 +96,14 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
         atomicStore(&error, 1);
         return;
     }
+#else
+    let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
+    let idx_val = idx[idx_i];
+#endif
 
-    let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
+    let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
     let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
 
     let col_idx = (gid.x % elems_per_row);
-    dst[i_dst_row/{{VEC_SIZE}} + col_idx] = {{DST_TYPE}}(src[i_src_row/{{VEC_SIZE}} + col_idx]);
+    dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]);
 }
-
-#end(SHADER)
-

+ 55 - 0
ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl

@@ -0,0 +1,55 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+struct Params {
+    offset_src: u32, // in elements
+    offset_dst: u32, // in elements
+
+    // Strides (in elements)
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    ne0: u32,
+    ne1: u32,
+    ne2: u32
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+var<workgroup> shared_sum: array<f32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+
+    var i = wid.x;
+    let i3 = i / (params.ne2 * params.ne1);
+    i = i % (params.ne2 * params.ne1);
+    let i2 = i / params.ne1;
+    let i1 = i % params.ne1;
+    let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
+    var local_sum: f32 = 0.0;
+    for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
+        local_sum += src[i_src_row + col];
+    }
+    shared_sum[lid.x] = local_sum;
+    workgroupBarrier();
+    // reduce within workgroup
+    var offset: u32 = WG_SIZE >> 1;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            shared_sum[lid.x] = shared_sum[lid.x] + shared_sum[lid.x + offset];
+        }
+        workgroupBarrier();
+        offset >>= 1;
+    }
+
+    if (lid.x == 0) {
+        dst[params.offset_dst + wid.x] = shared_sum[0];
+    }
+}

+ 179 - 0
ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl

@@ -0,0 +1,179 @@
+#ifdef TYPE_F16
+enable f16;
+#define TYPE f16
+#else
+#define TYPE f32
+#endif
+
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<TYPE>;
+
+#ifndef INPLACE
+@group(0) @binding(1)
+var<storage, read_write> dst: array<TYPE>;
+#define PARAMS_BINDING 2
+#else
+#define PARAMS_BINDING 1
+#endif
+
+struct Params {
+    ne: u32,            // total number of elements
+    offset_src: u32,    // in elements
+    offset_dst: u32,    // in elements
+
+    // Strides (in elements)
+    stride_src0: u32,
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    // Logical shapes
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+#ifdef CLAMP
+    clamp_min: f32,
+    clamp_max: f32,
+#endif
+#ifdef FILL
+    fill_val: f32,
+#endif
+#ifdef XIELU
+    alpha_n: f32,
+    alpha_p: f32,
+    beta: f32,
+    eps: f32,
+#endif
+
+};
+
+@group(0) @binding(PARAMS_BINDING)
+var<uniform> params: Params;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+      return;
+    }
+    var i = gid.x;
+    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+    i = i % (params.ne2 * params.ne1 * params.ne0);
+    let i2 = i / (params.ne1 * params.ne0);
+    i = i % (params.ne1 * params.ne0);
+    let i1 = i / params.ne0;
+    let i0 = i % params.ne0;
+
+    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
+                  i2 * params.stride_src2 + i3 * params.stride_src3;
+
+#ifdef ABS
+    let res = abs(src[params.offset_src + src_idx]);
+#endif
+#ifdef SGN
+    let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0),
+                     src[params.offset_src + src_idx] > 0.0);
+#endif
+#ifdef NEG
+    let res = -src[params.offset_src + src_idx];
+#endif
+#ifdef STEP
+    let res = TYPE(select(0.0, 1.0, src[params.offset_src + src_idx] > 0.0));
+#endif
+#ifdef TANH
+    let res = tanh(clamp(src[params.offset_src + src_idx], -9.010913, 9.010913));
+#endif
+#ifdef RELU
+    let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0);
+#endif
+#ifdef ELU
+    let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx],
+                     src[params.offset_src + src_idx] > 0.0);
+#endif
+#ifdef HARDSIGMOID
+    let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
+#endif
+#ifdef SIGMOID
+    let res = 1.0 / (1.0 + exp(-src[params.offset_src + src_idx]));
+#endif
+#ifdef SILU
+    let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx]));
+#endif
+#ifdef EXP
+    let res = exp(src[params.offset_src + src_idx]);
+#endif
+#ifdef LOG
+    let res = TYPE(log(f32(src[params.offset_src + src_idx])));
+#endif
+#ifdef CLAMP
+    let res = clamp(src[params.offset_src + src_idx], TYPE(params.clamp_min), TYPE(params.clamp_max));
+#endif
+#ifdef FILL
+    let res = TYPE(params.fill_val);
+#endif
+#ifdef HARDSWISH
+    let res = src[params.offset_src + src_idx] *
+              min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
+#endif
+#ifdef GELU
+    let res = 0.5 * src[params.offset_src + src_idx] *
+              (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) *
+                               (src[params.offset_src + src_idx] +
+                                0.044715 * pow(src[params.offset_src + src_idx], 3.0)),
+                               -9.010913, 9.010913)));
+#endif
+#ifdef GELU_QUICK
+    let res = src[params.offset_src + src_idx] * 0.5 *
+              (1.0 + tanh(clamp(0.79788456 *
+                               (src[params.offset_src + src_idx] +
+                                0.044715 * src[params.offset_src + src_idx] *
+                                    src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
+                               -9.010913, 9.010913)));
+#endif
+#ifdef GELU_ERF
+    let res = 0.5 * src[params.offset_src + src_idx] *
+              (1.0 + tanh(clamp(0.79788456 *
+                               (src[params.offset_src + src_idx] +
+                                0.044715 * src[params.offset_src + src_idx] *
+                                    src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
+                               -9.010913, 9.010913)));
+#endif
+#ifdef XIELU
+    let res =
+        select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) -
+                src[params.offset_src + src_idx]) *
+                   TYPE(params.alpha_n) +
+               TYPE(params.beta) * src[params.offset_src + src_idx],
+               TYPE(params.alpha_p) * src[params.offset_src + src_idx] *
+                   src[params.offset_src + src_idx] +
+                   TYPE(params.beta) * src[params.offset_src + src_idx],
+               src[params.offset_src + src_idx] > 0.0);
+#endif
+#ifdef SOFTPLUS
+    let src_f32 = f32(src[params.offset_src + src_idx]);
+    let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0));
+#endif
+#ifdef EXPM1
+    let res = exp(src[params.offset_src + src_idx]) - 1.0;
+#endif
+#ifdef FLOOR
+    let res = floor(src[params.offset_src + src_idx]);
+#endif
+#ifdef CEIL
+    let res = ceil(src[params.offset_src + src_idx]);
+#endif
+#ifdef ROUND
+    let src_f32 = f32(src[params.offset_src + src_idx]);
+    let result = select(ceil(src_f32 - 0.5), floor(src_f32 + 0.5), src_f32 >= 0.0);
+    let res = TYPE(result);
+#endif
+#ifdef TRUNC
+    let res = trunc(src[params.offset_src + src_idx]);
+#endif
+
+#ifdef INPLACE
+    src[params.offset_src + src_idx] = res;
+#else
+    dst[params.offset_dst + gid.x] = res;
+#endif
+}

+ 0 - 483
ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl

@@ -1,483 +0,0 @@
-#define(REPL_TEMPLATES)
-
-{
-    "XIELU_FUNC": "{{MUTATE}}[dst_i] = select(((exp(min(src[src_i], {{TYPE}}(params.eps))) - 1.0) - src[src_i]) * {{TYPE}}(params.alpha_n) + {{TYPE}}(params.beta) * src[src_i], {{TYPE}}(params.alpha_p) * src[src_i] * src[src_i] + {{TYPE}}(params.beta) * src[src_i], src[src_i] > 0.0);",
-    "ABS_FUNC": "{{MUTATE}}[dst_i] = abs(src[src_i]);",
-    "SGN_FUNC": "{{MUTATE}}[dst_i] = select({{TYPE}}(select(0.0, -1.0, src[src_i] < 0.0)), {{TYPE}}(1.0), src[src_i] > 0.0);",
-    "NEG_FUNC": "{{MUTATE}}[dst_i] = -src[src_i];",
-    "STEP_FUNC": "{{MUTATE}}[dst_i] = {{TYPE}}(select(0.0, 1.0, src[src_i] > 0.0));",
-    "TANH_FUNC": "{{MUTATE}}[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
-    "RELU_FUNC": "{{MUTATE}}[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);",
-    "ELU_FUNC": "{{MUTATE}}[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);",
-    "HARDSIGMOID_FUNC": "{{MUTATE}}[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
-    "SIGMOID_FUNC": "{{MUTATE}}[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));",
-    "SILU_FUNC": "{{MUTATE}}[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));",
-    "EXP_FUNC": "{{MUTATE}}[dst_i] = exp(src[src_i]);",
-    "HARDSWISH_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
-    "GELU_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
-    "GELU_QUICK_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
-    "GELU_ERF_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
-    "CEIL_FUNC": "{{MUTATE}}[dst_i] = ceil(src[src_i]);"
-}
-
-#end(REPL_TEMPLATES)
-
-#define(VARIANTS)
-
-[
-    {
-      "SHADER_NAME": "abs_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "abs_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "abs_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "abs_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "sgn_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "sgn_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "sgn_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "sgn_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "neg_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "neg_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "neg_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "neg_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "step_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "step_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "step_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "step_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "tanh_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "tanh_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "tanh_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "tanh_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "elu_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "elu_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "elu_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "elu_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "relu_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "relu_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "relu_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "relu_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "sigmoid_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "sigmoid_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "sigmoid_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "sigmoid_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "silu_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "silu_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "silu_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "silu_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "exp_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "exp_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "exp_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "exp_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "hardsigmoid_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "hardsigmoid_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "hardsigmoid_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "hardsigmoid_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "hardswish_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "hardswish_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "hardswish_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "hardswish_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "gelu_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "gelu_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "gelu_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "gelu_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "gelu_quick_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "gelu_quick_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "gelu_quick_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "gelu_quick_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "xielu_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "xielu_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "xielu_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "xielu_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-        "SHADER_NAME": "gelu_erf_f32",
-        "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-        "DECLS": ["NOT_INPLACE"]
-    },
-    {
-        "SHADER_NAME": "gelu_erf_f16",
-        "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-        "DECLS": ["NOT_INPLACE"]
-    },
-    {
-        "SHADER_NAME": "gelu_erf_inplace_f32",
-        "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-        "DECLS": ["INPLACE"]
-    },
-    {
-        "SHADER_NAME": "gelu_erf_inplace_f16",
-        "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-        "DECLS": ["INPLACE"]
-    },
-
-    {
-        "SHADER_NAME": "ceil_f32",
-        "REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-        "DECLS": ["NOT_INPLACE"]
-    },
-    {
-        "SHADER_NAME": "ceil_f16",
-        "REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-        "DECLS": ["NOT_INPLACE"]
-    },
-    {
-        "SHADER_NAME": "ceil_inplace_f32",
-        "REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-        "DECLS": ["INPLACE"]
-    },
-    {
-        "SHADER_NAME": "ceil_inplace_f16",
-        "REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-        "DECLS": ["INPLACE"]
-    }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(INPLACE)
-
-@group(0) @binding(1)
-var<uniform> params: Params;
-
-#enddecl(INPLACE)
-
-#decl(NOT_INPLACE)
-
-@group(0) @binding(1)
-var<storage, read_write> dst: array<{{TYPE}}>;
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-#enddecl(NOT_INPLACE)
-
-#end(DECLS)
-
-#define(SHADER)
-
-enable f16;
-
-fn update(dst_i: u32, src_i: u32) {
-    {{FUNC}}
-}
-
-@group(0) @binding(0)
-var<storage, read_write> src: array<{{TYPE}}>;
-
-DECLS
-
-struct Params {
-    ne: u32,            // total number of elements
-    offset_src: u32,    // in elements
-    offset_dst: u32,    // in elements
-
-    // Strides (in elements) — may be permuted
-    stride_src0: u32,
-    stride_src1: u32,
-    stride_src2: u32,
-    stride_src3: u32,
-
-    stride_dst0: u32,
-    stride_dst1: u32,
-    stride_dst2: u32,
-    stride_dst3: u32,
-
-    // Logical shapes
-    src_ne0: u32,
-    src_ne1: u32,
-    src_ne2: u32,
-
-    dst_ne0: u32,
-    dst_ne1: u32,
-    dst_ne2: u32,
-
-    {{EXT_PARAMS}}
-};
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x >= params.ne) {
-      return;
-    }
-
-    var i = gid.x;
-    let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
-    i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
-    let i2 = i / (params.src_ne1 * params.src_ne0);
-    i = i % (params.src_ne1 * params.src_ne0);
-    let i1 = i / params.src_ne0;
-    let i0 = i % params.src_ne0;
-
-    var j = gid.x;
-    let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
-    j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
-    let j2 = j / (params.dst_ne1 * params.dst_ne0);
-    j = j % (params.dst_ne1 * params.dst_ne0);
-    let j1 = j / params.dst_ne0;
-    let j0 = j % params.dst_ne0;
-
-    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
-                  i2 * params.stride_src2 + i3 * params.stride_src3;
-
-    let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
-                  j2 * params.stride_dst2 + j3 * params.stride_dst3;
-
-
-    update(params.offset_dst + dst_idx, params.offset_src + src_idx);
-}
-
-#end(SHADER)
-

Some files were not shown because too many files changed in this diff