Ver Fonte

Add support for CUMSUM and TRI for CUDA. (#17584)

* Add support for CUMSUM and TRI for CUDA.

* Minor optimizations.

* Correct warp_prefix_inclusive_sum in float2 variant to return float2

* Optimize TRI

* Whitespace

* Fix strides.

* Implement double loop

* Whitespace

* Fix HIP compilation bugs

* Optimizations + big case performance tests

* Implement using CUB with fallback to custom kernel

* Remove error message.

* Fixes from code review

* Comment out CPU-unsupported F16/BF16 cases to fix CI

* Fine, you win :P

* Fix last cast, use NO_DEVICE_CODE and GGML_UNUSED_VARS

* Vary warp-size based on physical warp size

* Add GGML_UNUSED_VARS in tri as well

* Use constexpr and call prefix_inclusive with warp_size template param

* Update ggml/src/ggml-cuda/cumsum.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Change to tid % warp_size

* Fix strides; hardcode mask; add ggml_lane_mask_t

* Missing renames, remove unused get_warp_mask(), explicit calls to ggml_cuda_info()

* Too hasty...

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Piotr Wilkin (ilintar) há 1 mês atrás
pai
commit
96fe9badfc

+ 47 - 0
ggml/src/ggml-cuda/common.cuh

@@ -463,6 +463,53 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
+template<typename T, int width = WARP_SIZE>
+static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
+    const int lane_id = threadIdx.x % width;
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const T t = __shfl_up_sync(0xffffffff, x, offset, width);
+        if (lane_id >= offset) {
+            x += t;
+        }
+    }
+    return x;
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
+    const int lane_id = threadIdx.x % width;
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const float t_x = __shfl_up_sync(0xffffffff, a.x, offset, width);
+        const float t_y = __shfl_up_sync(0xffffffff, a.y, offset, width);
+        if (lane_id >= offset) {
+            a.x += t_x;
+            a.y += t_y;
+        }
+    }
+    return a;
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
+#ifdef FP16_AVAILABLE
+    const int lane_id = threadIdx.x % width;
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const half2 t = __shfl_up_sync(0xffffffff, a, offset, width);
+        if (lane_id >= offset) {
+            a = __hadd2(a, t);
+        }
+    }
+    return a;
+
+#else
+    NO_DEVICE_CODE;
+    return a;
+#endif // FP16_AVAILABLE
+}
+
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
 #ifdef FP16_AVAILABLE
 

+ 237 - 0
ggml/src/ggml-cuda/cumsum.cu

@@ -0,0 +1,237 @@
+#include <algorithm>
+#include "cumsum.cuh"
+#include "convert.cuh"
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
+
+#ifdef GGML_CUDA_USE_CUB
+#   include <cub/device/device_scan.cuh>
+#endif // GGML_CUDA_USE_CUB
+
+template<typename T, int BLOCK_SIZE>
+static __global__ void cumsum_cub_kernel(
+        const T * __restrict__ src,
+        T * __restrict__ dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t  s01, const int64_t  s02, const int64_t  s03,
+        const int64_t   s1,  const int64_t   s2,  const int64_t   s3) {
+#ifdef GGML_CUDA_USE_CUB
+    using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;
+
+    __shared__ typename BlockScan::TempStorage temp_storage;
+    __shared__ T block_carry;      // carry from previous tile
+
+    const int tid = threadIdx.x;
+
+    const int64_t i1 = blockIdx.x;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i3 = blockIdx.z;
+
+    if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
+        return;
+    }
+
+    const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
+    T *       dst_row = dst + i1 * s1  + i2 * s2  + i3 * s3;
+
+    if (tid == 0) {
+        block_carry = 0;
+    }
+    __syncthreads();
+
+    for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) {
+        int64_t idx = start + tid;
+        T x = (idx < ne00) ? src_row[idx] : T(0);
+
+        T inclusive;
+        T block_total;
+        BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total);
+
+        __syncthreads();
+
+        T final_val = inclusive + block_carry;
+
+        // store result
+        if (idx < ne00) {
+            dst_row[idx] = final_val;
+        }
+
+        __syncthreads();
+
+        if (tid == 0) {
+            block_carry += block_total;
+        }
+
+        __syncthreads();
+    }
+#else
+    NO_DEVICE_CODE;
+#endif // GGML_CUDA_USE_CUB
+}
+
+// Fallback kernel implementation (original)
+template<typename T>
+static __global__ void cumsum_kernel(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t  s00, const int64_t  s01, const int64_t  s02, const int64_t  s03,
+        const int64_t   s0, const int64_t   s1, const int64_t   s2, const int64_t   s3) {
+
+    GGML_UNUSED_VARS(s00, s0);
+
+    const int tid = threadIdx.x;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    const int lane = tid % warp_size;
+    const int warp = tid / warp_size;
+    const int warps_per_block = blockDim.x / warp_size;
+
+    extern __shared__ float smem[];
+    float * s_vals = smem;
+    float * s_warp_sums = smem + blockDim.x;
+    float * s_carry = smem + blockDim.x + warps_per_block;
+    float * s_chunk_total = s_carry + 1;
+
+    // Initialize carry
+    if (tid == 0) {
+        *s_carry = 0.0f;
+    }
+    __syncthreads();
+
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
+    T       * dst_row = dst + i1 * s1  + i2 * s2  + i3 * s3;
+
+    for (int64_t start = 0; start < ne00; start += blockDim.x) {
+        int64_t idx = start + tid;
+        float val = (idx < ne00) ? ggml_cuda_cast<float, T>(src_row[idx]) : 0.0f;
+
+        // 1. Warp inclusive scan
+        val = warp_prefix_inclusive_sum<T, warp_size>(val);
+        s_vals[tid] = val;
+
+        // Store warp total
+        if (lane == warp_size - 1) {
+            s_warp_sums[warp] = val;
+        }
+        __syncthreads();
+
+        // 2. Exclusive scan of warp sums (warp 0 only)
+        if (warp == 0) {
+            float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
+            float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
+            if (tid < warps_per_block) {
+                s_warp_sums[tid] = inc - w;   // exclusive sum
+            }
+            if (tid == warps_per_block - 1) {
+                *s_chunk_total = inc;          // total sum of this chunk
+            }
+        }
+        __syncthreads();
+
+        float carry = *s_carry;
+        float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
+        if (idx < ne00) {
+            dst_row[idx] = ggml_cuda_cast<T, float>(final_val);
+        }
+        __syncthreads();
+
+        // Update carry for next chunk
+        if (tid == 0) {
+            *s_carry += *s_chunk_total;
+        }
+        __syncthreads();
+    }
+}
+
+template<typename T>
+static void cumsum_cuda(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t  nb0,  const int64_t nb1, const int64_t  nb2, const int64_t  nb3,
+        cudaStream_t stream) {
+
+    const size_t type_size = sizeof(T);
+    bool use_cub = false;
+#ifdef GGML_CUDA_USE_CUB
+    // Check if we can use CUB (data must be contiguous along innermost dimension)
+    const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size);
+
+    if (is_contiguous) {
+        use_cub = true;
+    }
+#endif // GGML_CUDA_USE_CUB
+    dim3 grid_dims(ne01, ne02, ne03);
+    const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()];
+    const int warp_size = info.warp_size;
+    const int num_warps = (ne00 + warp_size - 1) / warp_size;
+    int block_size = num_warps * warp_size;
+    block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
+    dim3 block_dims(block_size, 1, 1);
+    const int warps_per_block = block_size / warp_size;
+    const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
+
+    if (use_cub) {
+        cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
+            src, dst,
+            ne00, ne01, ne02, ne03,
+            nb01 / type_size, nb02 / type_size, nb03 / type_size,
+            nb1 / type_size,  nb2 / type_size,  nb3 / type_size
+        );
+    } else {
+        cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
+            src, dst,
+            ne00, ne01, ne02, ne03,
+            nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+            nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+        );
+    }
+}
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == dst->type);
+    switch(src0->type) {
+        case GGML_TYPE_F32:
+            {
+                cumsum_cuda(
+                    (const float *)src0->data, (float *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    stream
+                );
+            } break;
+        // We do not support those on CPU for now anyway, so comment them out because they cause errors on some CI platforms
+        /*case GGML_TYPE_F16:
+            {
+                cumsum_cuda(
+                    (const half *)src0->data, (half *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    stream
+                );
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                cumsum_cuda(
+                    (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    stream
+                );
+            } break;*/
+        default:
+            GGML_ABORT("fatal error");
+    }
+}

+ 5 - 0
ggml/src/ggml-cuda/cumsum.cuh

@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CUMSUM_BLOCK_SIZE 256
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 10 - 0
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -54,6 +54,8 @@
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
 #include "ggml-cuda/solve_tri.cuh"
+#include "ggml-cuda/tri.cuh"
+#include "ggml-cuda/cumsum.cuh"
 #include "ggml.h"
 
 #include <algorithm>
@@ -2701,6 +2703,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
+        case GGML_OP_CUMSUM:
+            ggml_cuda_op_cumsum(ctx, dst);
+            break;
+        case GGML_OP_TRI:
+            ggml_cuda_op_tri(ctx, dst);
+            break;
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
@@ -4609,6 +4617,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
+        case GGML_OP_CUMSUM:
+        case GGML_OP_TRI:
             return true;
         case GGML_OP_SOLVE_TRI:
             return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;

+ 136 - 0
ggml/src/ggml-cuda/tri.cu

@@ -0,0 +1,136 @@
+#include "common.cuh"
+#include "convert.cuh"
+#include "tri.cuh"
+#include "ggml.h"
+
+template<typename T, bool prefix_keep, int add_to_split>
+static __global__ void tri_kernel(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3) {
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+    const int64_t split_point = i1 + add_to_split;
+
+    GGML_UNUSED_VARS(nb00, nb0);
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03;
+    T       * dst_row = dst + i1*nb1  + i2*nb2  + i3*nb3;
+
+    if constexpr (prefix_keep) {
+        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
+            dst_row[i0] = src_row[i0];
+        }
+        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
+            dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
+        }
+    } else {
+        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
+            dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
+        }
+        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
+            dst_row[i0] = src_row[i0];
+        }
+    }
+}
+
+template<typename T>
+static void tri_cuda(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3,
+        const ggml_tri_type ttype,
+        cudaStream_t stream) {
+
+    dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
+    dim3 grid_dims(ne01, ne02, ne03);
+    const size_t type_size = sizeof(T);
+
+    const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;
+    const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);
+
+    if (prefix_keep) {
+        if (add_to_split == 0) {
+            tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>(
+                src, dst,
+                ne00, ne01, ne02, ne03,
+                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+            );
+        } else { // only 0 and 1 supported
+            tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>(
+                src, dst,
+                ne00, ne01, ne02, ne03,
+                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+            );
+        }
+    } else {
+        if (add_to_split == 0) {
+            tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>(
+                src, dst,
+                ne00, ne01, ne02, ne03,
+                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+            );
+        } else {
+            tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>(
+                src, dst,
+                ne00, ne01, ne02, ne03,
+                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+            );
+        }
+    }
+}
+
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    cudaStream_t stream = ctx.stream();
+
+    const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));
+
+    GGML_ASSERT(src0->type == dst->type);
+
+    switch(src0->type) {
+        case GGML_TYPE_F32:
+            {
+                tri_cuda(
+                    (const float *)src0->data, (float *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    ttype, stream
+                );
+            } break;
+        case GGML_TYPE_F16:
+            {
+                tri_cuda(
+                    (const half *)src0->data, (half *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    ttype, stream
+                );
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                tri_cuda(
+                    (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    ttype, stream
+                );
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+}

+ 5 - 0
ggml/src/ggml-cuda/tri.cuh

@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_TRI_BLOCK_SIZE 256
+
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 8 - 0
tests/test-backend-ops.cpp

@@ -7725,6 +7725,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 10, 5, 4, 3 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 127, 5, 4, 3 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 255, 5, 4, 3 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 256, 5, 4, 3 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 511, 5, 4, 3 }));
@@ -7954,6 +7955,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
 
+    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
+    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
+
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 16, 5, 4 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20000, 10, 4, 1 }));
+
     for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
         for (ggml_type type_a : all_types) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {