Browse Source

CUDA: Factor out and re-use `block_reduce` function (#18785)

* CUDA: Refactor and expose two_stage_warp_reduce_* function

* Use `two_stage_warp_reduce` also in softmax kernel, move smem out of it

Moving smem out of `__device__` function to `__global__` function
allows for explicit smem reuse, as either compiler or cuda rt seem to not
free it afterwards (`cudaFuncSetAttribute` fails when not accounting for
it once for each call to two_stage_warp_reduce)

* Update ggml/src/ggml-cuda/common.cuh

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* Use two_stage_warp_reduce in group_norm_f32

* Use two_stage_warp_reduce in rms_norm_f32

* Fix smem calculation which expects bytes

* Make `two_stage_warp_reduce` accept all values warp_reduce accepts

Also integrate it into norm_f32 function

* Use two_stage_warp_reduce in l2_norm_f32

* Use type traits for block reduction for better legibility

Also adresss other requests by @am17an such as variable renaming

* Make norm tests cover all cuda paths

* Mark columns % WARP_SIZE !=0 as supported for RMS_NORM_BACK

Unit-tests passed locally, let's see if they pass in the CI as well

* Use `enum class` for `block_reduce_method`

This is more type-safe than plain enum

* Rename variables as suggested in code review by @am17an

* Rename two_stage_warp_reduce -> block_reduce

* Fix trailing whitespace in common.cuh

* Make condition of static_assert type-dependent

This delays evaluation until the template is actually instantiated.
Otherwise, some compilers may evaluate the assert when parsing the
template, resulting in build errors as observed here:

https://github.com/ggml-org/llama.cpp/actions/runs/20960323123/job/60235530068?pr=18785

* Inline definitions

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
Oliver Simons 1 week ago
parent
commit
36f0132464

+ 80 - 0
ggml/src/ggml-cuda/common.cuh

@@ -530,6 +530,86 @@ static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
+enum class block_reduce_method {
+    MAX,
+    SUM,
+};
+
+template<block_reduce_method method_t, typename T>
+struct block_reduce_policy;
+
+template <typename T, typename... Ts>
+inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);
+
+template<typename...>
+inline constexpr bool ggml_cuda_dependent_false_v = false;
+
+template <typename T> struct block_reduce_policy<block_reduce_method::SUM, T> {
+    static __device__ T reduce(T val) {
+        if constexpr(is_any<T, float, float2, half2, int>) {
+            return warp_reduce_sum(val);
+        } else {
+            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
+        }
+    }
+
+    static __device__ T sentinel() {
+        if constexpr (std::is_same_v<T, float>) {
+            return 0.0f;
+        } else if constexpr (std::is_same_v<T, float2>) {
+            return make_float2(0.0f, 0.0f);
+        } else if constexpr (std::is_same_v<T, half2>) {
+            return make_half2(0.0f, 0.0f);
+        } else if constexpr (std::is_same_v<T, int>) {
+            return 0;
+        } else {
+            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
+        }
+    }
+};
+
+template <typename T> struct block_reduce_policy<block_reduce_method::MAX, T> {
+    static __device__ T reduce(T val) {
+        if constexpr (is_any<T, float, half2>) {
+            return warp_reduce_max(val);
+        } else {
+            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
+        }
+    }
+
+    static __device__ T sentinel() {
+        if constexpr (std::is_same_v<T, float>) {
+            return -INFINITY;
+        } else if constexpr (std::is_same_v<T, half2>) {
+            return make_half2(-INFINITY, -INFINITY);
+        } else {
+            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
+        }
+    }
+};
+
+template <block_reduce_method reduce_method_t, const unsigned int block_size_template = 0, typename T>
+static __device__ T block_reduce(T val, T * shared_vals) {
+    val                           = block_reduce_policy<reduce_method_t, T>::reduce(val);
+    const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+    if (block_size > WARP_SIZE) {
+        assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0);
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            shared_vals[warp_id] = val;
+        }
+        __syncthreads();
+        val = block_reduce_policy<reduce_method_t, T>::sentinel();
+        if (lane_id < (static_cast<int>(block_size) / WARP_SIZE)) {
+            val = shared_vals[lane_id];
+        }
+        return block_reduce_policy<reduce_method_t, T>::reduce(val);
+    }
+
+    return val;
+}
+
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
 #ifdef FP16_AVAILABLE
 

+ 1 - 1
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -4551,7 +4551,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_L2_NORM:
             return true;
         case GGML_OP_RMS_NORM_BACK:
-            return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
+            return ggml_is_contiguous(op->src[0]);
             break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:

+ 18 - 76
ggml/src/ggml-cuda/norm.cu

@@ -25,19 +25,8 @@ static __global__ void norm_f32(
     }
 
     // sum up partial sums
-    mean_var = warp_reduce_sum(mean_var);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert(block_size == 1024, "unexpected block_size");
-        __shared__ float2 s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = mean_var;
-        }
-        __syncthreads();
-        mean_var = s_sum[lane_id];
-        mean_var = warp_reduce_sum(mean_var);
-    }
+    extern __shared__ float2 s_sum2[];
+    mean_var = block_reduce<block_reduce_method::SUM, block_size>(mean_var, s_sum2);
 
     const float mean = mean_var.x / ncols;
     const float var = mean_var.y / ncols - mean * mean;
@@ -61,19 +50,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
         tmp += x[j];
     }
 
-    tmp = warp_reduce_sum(tmp);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert(block_size == 1024, "unexpected block_size");
-        __shared__ float s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    extern __shared__ float s_sum[];
+    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
 
     const float mean = tmp / group_size;
     tmp = 0.0f;
@@ -84,18 +62,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
         tmp += xi * xi;
     }
 
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        __shared__ float s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
 
     const float variance = tmp / group_size;
     const float scale = rsqrtf(variance + eps);
@@ -163,22 +130,8 @@ static __global__ void rms_norm_f32(const float * x,
     }
 
     // sum up partial sums
-    tmp = warp_reduce_sum(tmp);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
-        __shared__ float s_sum[32];
-        const int        warp_id = tid / WARP_SIZE;
-        const int        lane_id = tid % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = 0.0f;
-        if (lane_id < (block_size / WARP_SIZE)) {
-            tmp = s_sum[lane_id];
-        }
-        tmp = warp_reduce_sum(tmp);
-    }
+    extern __shared__ float s_sum[];
+    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
 
     const float mean = tmp / ncols;
     const float scale = rsqrtf(mean + eps);
@@ -306,19 +259,8 @@ static __global__ void l2_norm_f32(
     }
 
     // sum up partial sums
-    tmp = warp_reduce_sum(tmp);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert(block_size == 1024, "unexpected block_size");
-        __shared__ float s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    extern __shared__ float s_sum[];
+    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
 
     // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
     const float scale = rsqrtf(fmaxf(tmp, eps * eps));
@@ -337,7 +279,7 @@ static void norm_f32_cuda(
         norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
@@ -348,7 +290,7 @@ static void group_norm_f32_cuda(
         group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+        group_norm_f32<1024><<<num_groups, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps);
     }
 }
 
@@ -358,10 +300,10 @@ static void rms_norm_f32_cuda(
     const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(256, 1, 1);
-        rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<256, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<1024, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
@@ -404,12 +346,12 @@ static void rms_norm_mul_f32_cuda(const float *  x,
         const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);
         if (ncols < 1024) {
             const dim3 block_dims(256, 1, 1);
-            rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
+            rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
         } else {
             const dim3 block_dims(1024, 1, 1);
-            rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
+            rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
         }
@@ -425,14 +367,14 @@ static void rms_norm_mul_f32_cuda(const float *  x,
         const uint3 add_nsamples_packed  = init_fastdiv_values(add_nsamples);
         if (ncols < 1024) {
             const dim3 block_dims(256, 1, 1);
-            rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
+            rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
                 add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
                 add_nchannels_packed, add_nsamples_packed);
         } else {
             const dim3 block_dims(1024, 1, 1);
-            rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
+            rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
                 add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
@@ -460,7 +402,7 @@ static void l2_norm_f32_cuda(
         l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        l2_norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 

+ 2 - 16
ggml/src/ggml-cuda/reduce_rows.cuh

@@ -28,22 +28,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r
     }
 
     // sum up partial sums
-    sum = warp_reduce_sum(sum);
-    if (blockDim.x > WARP_SIZE) {
-        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
-        __shared__ float s_sum[32];
-        const int        warp_id = threadIdx.x / WARP_SIZE;
-        const int        lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = sum;
-        }
-        __syncthreads();
-        sum = 0.0f;
-        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
-            sum = s_sum[lane_id];
-        }
-        sum = warp_reduce_sum(sum);
-    }
+    __shared__ float shared_vals[32];
+    sum = block_reduce<block_reduce_method::SUM>(sum, shared_vals);
 
     if (col != 0) {
         return;

+ 7 - 82
ggml/src/ggml-cuda/softmax.cu

@@ -75,9 +75,6 @@ static __global__ void soft_max_f32(
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
-    const int warp_id = threadIdx.x / WARP_SIZE;
-    const int lane_id = threadIdx.x % WARP_SIZE;
-
     const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
 
     extern __shared__ float data_soft_max_f32[];
@@ -102,21 +99,7 @@ static __global__ void soft_max_f32(
     }
 
     // find the max value in the block
-    max_val = warp_reduce_max(max_val);
-    if (block_size > WARP_SIZE) {
-        if (warp_id == 0) {
-            buf_iw[lane_id] = -INFINITY;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf_iw[warp_id] = max_val;
-        }
-        __syncthreads();
-
-        max_val = buf_iw[lane_id];
-        max_val = warp_reduce_max(max_val);
-    }
+    max_val = block_reduce<block_reduce_method::MAX, block_size_template>(max_val, buf_iw);
 
     float tmp = 0.0f; // partial sum
 
@@ -134,22 +117,7 @@ static __global__ void soft_max_f32(
     }
 
     // find the sum of exps in the block
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        __syncthreads();
-        if (warp_id == 0) {
-            buf_iw[lane_id] = 0.0f;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf_iw[warp_id] = tmp;
-        }
-        __syncthreads();
-
-        tmp = buf_iw[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    tmp = block_reduce<block_reduce_method::SUM, block_size_template>(tmp, buf_iw);
 
     if (sinks) {
         tmp += expf(sinks[i02] - max_val);
@@ -169,50 +137,6 @@ static __global__ void soft_max_f32(
     }
 }
 
-
-// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated
-static __device__ float two_stage_warp_reduce_max(float val) {
-    val = warp_reduce_max(val);
-    if (blockDim.x > WARP_SIZE) {
-        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
-        __shared__ float local_vals[32];
-        const int        warp_id = threadIdx.x / WARP_SIZE;
-        const int        lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            local_vals[warp_id] = val;
-        }
-        __syncthreads();
-        val = -INFINITY;
-        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
-            val = local_vals[lane_id];
-        }
-        return warp_reduce_max(val);
-    } else {
-        return val;
-    }
-}
-
-static __device__ float two_stage_warp_reduce_sum(float val) {
-    val = warp_reduce_sum(val);
-    if (blockDim.x > WARP_SIZE) {
-        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
-        __shared__ float local_vals[32];
-        const int        warp_id = threadIdx.x / WARP_SIZE;
-        const int        lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            local_vals[warp_id] = val;
-        }
-        __syncthreads();
-        val = 0.0f;
-        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
-            val = local_vals[lane_id];
-        }
-        return warp_reduce_sum(val);
-    } else {
-        return val;
-    }
-}
-
 // TODO: Template to allow keeping ncols in registers if they fit
 static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
                                                                 float * __restrict__ dst,
@@ -230,6 +154,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     float     local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
     float     local_max                     = -INFINITY;
     const int step_size                     = gridDim.x * blockDim.x;
+    __shared__ float shared_vals[32];
 
     // Compute thread-local max
     for (int col = col_start; col < p.ncols;) {
@@ -246,7 +171,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     }
 
     // Compute CTA-level max
-    local_max = two_stage_warp_reduce_max(local_max);
+    local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);
 
     // Store CTA-level max to GMEM
     if (tid == 0) {
@@ -261,7 +186,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     } else {
         local_max = -INFINITY;
     }
-    local_max = two_stage_warp_reduce_max(local_max);
+    local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);
 
     // Compute softmax dividends, accumulate divisor
     float tmp_expf = 0.0f;
@@ -284,7 +209,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     }
 
     // Reduce divisor within CTA
-    tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+    tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);
 
     // Store CTA-level sum to GMEM
     if (tid == 0) {
@@ -298,7 +223,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     } else {
         tmp_expf = 0.0f;
     }
-    tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+    tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);
 
     // Divide dividend by global sum + store data
     for (int col = col_start; col < p.ncols;) {

+ 17 - 16
tests/test-backend-ops.cpp

@@ -7482,25 +7482,29 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
     test_cases.emplace_back(new test_silu_back());
 
-    for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
-        for (bool v : {false, true}) {
-            test_cases.emplace_back(new test_norm    (GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
-            test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
+    for (float eps : { 0.0f, 1e-6f, 1e-4f, 1e-1f }) {
+        for (uint32_t n : { 64, 1025 }) {
+            for (bool v : { false, true }) {
+                test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
+                test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
+            }
+            test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
+            test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
         }
-        test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
-        test_cases.emplace_back(new test_l2_norm      (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
 
     // in-place tests
     test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, false, 1e-6f, true));
 
-    for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
-        test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
-        test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
-        test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
-        test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
-        test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
-        test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
+    for (float eps : { 0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f }) {
+        for (uint32_t n : { 64, 1025 }) {
+            test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
+            test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
+            test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
+            test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
+            test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
+            test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
+        }
     }
     for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {
         for (bool multi_add : {false, true}) {
@@ -7524,9 +7528,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             }
         }
     }
-
-    test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
-
     for (int64_t d_conv : {3, 4, 9}) {
         for (int64_t d_inner: {1024, 1536, 2048}) {
             test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));