ggml : sync latest (SAM + SD operators, CUDA alibi) (#2709)

* ggml : sync latest (SAM + SD operators, CUDA alibi)

ggml-ci

* ggml : fix tabs
Georgi Gerganov 2 years ago
parent
commit ef3f333d37
6 changed files with 833 additions and 87 deletions
  1. examples/train-text-from-scratch/train-text-from-scratch.cpp (+2 -2)
  2. ggml-alloc.c (+2 -2)
  3. ggml-cuda.cu (+79 -0)
  4. ggml.c (+630 -68)
  5. ggml.h (+109 -6)
  6. scripts/sync-ggml.sh (+11 -9)

+ 2 - 2
examples/train-text-from-scratch/train-text-from-scratch.cpp

@@ -1868,10 +1868,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
         t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1));                                            assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
         t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd));                 assert_shape_2d(t11->grad, N*n_batch, n_embd);
         t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3));                                            assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
-        t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx));                     assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
+        t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false));        assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
         t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch));                                  assert_shape_2d(t08->grad, n_embd, N*n_batch);
         t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3));                                            assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
-        t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx));                     assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
+        t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false));        assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
         t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch));                                  assert_shape_2d(t05->grad, n_embd, N*n_batch);
         t04->grad = expand(gb, ggml_add_inplace(ctx0,
                         ggml_add_inplace(ctx0,
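
The four new trailing arguments correspond to the parameters added to ggml_rope_back in ggml.h below: freq_base, freq_scale, xpos_base and xpos_down. Judging by this hunk, passing 10000.0f, 1.0f, 0.0f, false preserves the previous fixed behavior. A minimal sketch of the updated call, with dy standing in for a hypothetical incoming gradient tensor:

    // equivalent to the old 6-argument ggml_rope_back call
    struct ggml_tensor * dx = ggml_rope_back(ctx0, dy,
            n_past, n_rot, rope_mode, n_ctx,
            10000.0f,   // freq_base:  default RoPE base
            1.0f,       // freq_scale: no frequency scaling
            0.0f,       // xpos_base:  xPos disabled
            false);     // xpos_down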

+ 2 - 2
ggml-alloc.c

@@ -76,7 +76,7 @@ struct ggml_allocr {
 };
 
 #ifdef GGML_ALLOCATOR_DEBUG
-static void add_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
+static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
     for (int i = 0; i < 1024; i++) {
         if (alloc->allocated_tensors[i] == NULL) {
             alloc->allocated_tensors[i] = tensor;
@@ -85,7 +85,7 @@ static void add_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tens
     }
     GGML_ASSERT(!"out of allocated_tensors");
 }
-static void remove_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
+static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
     for (int i = 0; i < 1024; i++) {
         if (alloc->allocated_tensors[i] == tensor ||
             (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {

+ 79 - 0
ggml-cuda.cu

@@ -259,6 +259,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
@@ -3940,6 +3941,29 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
     dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
 }
 
+static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
+                                 const int n_heads_log2_floor, const float m0, const float m1) {
+    const int col = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i = row*ncols + col;
+
+    const int k = row/k_rows;
+
+    float m_k;
+    if (k < n_heads_log2_floor) {
+        m_k = powf(m0, k + 1);
+    } else {
+        m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+    }
+
+    dst[i] = col * m_k + x[i];
+}
+
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
@@ -4766,6 +4790,15 @@ static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, con
     rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
 }
 
+static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
+                           const int k_rows, const int n_heads_log2_floor, const float m0,
+                           const float m1, cudaStream_t stream) {
+    const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1);
+    const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);
+    const dim3 block_nums(num_blocks_x, nrows, 1);
+    alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
+}
+
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
     const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
@@ -5501,6 +5534,41 @@ inline void ggml_cuda_op_rope(
     (void) i1;
 }
 
+inline void ggml_cuda_op_alibi(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_head = ((int32_t *) dst->op_params)[1];
+    float max_bias;
+    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
+
+    GGML_ASSERT(ne01 + n_past == ne00);
+    GGML_ASSERT(n_head == ne02);
+
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
+
+    // compute
+    alibi_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_heads_log2_floor, m0, m1, cudaStream_main);
+
+    (void) src1;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -6121,6 +6189,11 @@ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
 }
 
+void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_alibi, true, true);
+}
+
 void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -6456,6 +6529,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_rope;
             break;
+        case GGML_OP_ALIBI:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_alibi;
+            break;
         default:
             return false;
     }
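
For reference, the kernel implements the usual ALiBi slopes: head k is assigned a multiplier m_k drawn from a geometric sequence built from m0 and m1 (computed in ggml_cuda_op_alibi above), and the bias added at column col is col * m_k. A minimal CPU sketch of the same slope computation, handy for checking the CUDA output against a host reference (the helper name is hypothetical):

    #include <math.h>

    // mirrors ggml_cuda_op_alibi + alibi_f32: returns the ALiBi slope m_k
    // for head k out of n_head heads, given max_bias
    static float alibi_slope(int k, int n_head, float max_bias) {
        const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
        const float m0 = powf(2.0f, -max_bias          / n_heads_log2_floor);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
        if (k < n_heads_log2_floor) {
            return powf(m0, k + 1);                      // first block of heads
        }
        return powf(m1, 2*(k - n_heads_log2_floor) + 1); // remaining heads
    }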

File diff suppressed because it is too large
+ 630 - 68
ggml.c


+ 109 - 6
ggml.h

@@ -211,6 +211,7 @@
 #define GGML_MAX_OP_PARAMS     32
 #define GGML_DEFAULT_N_THREADS 4
 
+
 #define GGML_EXIT_SUCCESS 0
 #define GGML_EXIT_ABORTED 1
 
@@ -345,10 +346,12 @@ extern "C" {
         GGML_OP_ARGMAX,
         GGML_OP_REPEAT,
         GGML_OP_REPEAT_BACK,
+        GGML_OP_CONCAT,
         GGML_OP_SILU_BACK,
         GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
+        GGML_OP_GROUP_NORM,
 
         GGML_OP_MUL_MAT,
         GGML_OP_OUT_PROD,
@@ -374,14 +377,19 @@ extern "C" {
         GGML_OP_CLAMP,
         GGML_OP_CONV_1D,
         GGML_OP_CONV_2D,
+        GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
 
+        GGML_OP_UPSCALE, // nearest interpolate
+
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
         GGML_OP_FLASH_ATTN_BACK,
         GGML_OP_WIN_PART,
         GGML_OP_WIN_UNPART,
+        GGML_OP_GET_REL_POS,
+        GGML_OP_ADD_REL_POS,
 
         GGML_OP_UNARY,
 
@@ -805,6 +813,13 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // concat a and b on dim 2
+    // used in stable-diffusion
+    GGML_API struct ggml_tensor * ggml_concat(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     GGML_API struct ggml_tensor * ggml_abs(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
@@ -913,6 +928,19 @@ extern "C" {
             struct ggml_tensor  * a,
             float                 eps);
 
+    // group normalize along ne0*ne1*n_groups
+    // used in stable-diffusion
+    // TODO: eps is hardcoded to 1e-6 for now
+    GGML_API struct ggml_tensor * ggml_group_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_groups);
+
+    GGML_API struct ggml_tensor * ggml_group_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_groups);
+
     // a - x
     // b - dy
     // TODO: update with configurable eps
@@ -1213,6 +1241,15 @@ extern "C" {
             float                 freq_base,
             float                 freq_scale);
 
+    // xPos RoPE, in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            float                 base,
+            bool                  down);
+
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
     GGML_API struct ggml_tensor * ggml_rope_back(
@@ -1221,7 +1258,11 @@ extern "C" {
             int                   n_past,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx);
+            int                   n_ctx,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 xpos_base,
+            bool                  xpos_down);
 
     // alibi position embedding
     // in-place, returns view(a)
@@ -1248,6 +1289,15 @@ extern "C" {
             int                   p0,  // padding
             int                   d0); // dilation
 
+    // conv_1d with padding = half
+    // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
+    GGML_API struct ggml_tensor* ggml_conv_1d_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   s,
+            int                   d);
+
     GGML_API struct ggml_tensor * ggml_conv_2d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1259,14 +1309,38 @@ extern "C" {
             int                   d0,
             int                   d1);
 
-    // conv_1d with padding = half
-    // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
-    GGML_API struct ggml_tensor * ggml_conv_1d_ph(
+
+    // kernel size is a->ne[0] x a->ne[1]
+    // stride is equal to kernel size
+    // padding is zero
+    // example:
+    // a:     16   16    3  768
+    // b:   1024 1024    3    1
+    // res:   64   64  768    1
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // kernel size is a->ne[0] x a->ne[1]
+    // stride is 1
+    // padding is half
+    // example:
+    // a:      3    3    256  256
+    // b:     64   64    256    1
+    // res:   64   64    256    1
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
-            int                   s,
-            int                   d);
+            int                   stride);
 
     enum ggml_op_pool {
         GGML_OP_POOL_MAX,
@@ -1293,6 +1367,13 @@ extern "C" {
             int                   p0,
             int                   p1);
 
+    // nearest interpolate
+    // used in stable-diffusion
+    GGML_API struct ggml_tensor * ggml_upscale(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   scale_factor);
+
     GGML_API struct ggml_tensor * ggml_flash_attn(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
@@ -1346,6 +1427,27 @@ extern "C" {
         struct ggml_tensor  * a,
         enum ggml_unary_op op);
 
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_get_rel_pos(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   qh,
+            int                   kh);
+
+    // used in sam
+
+    GGML_API struct ggml_tensor * ggml_add_rel_pos(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * pw,
+            struct ggml_tensor  * ph);
+
+    GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * pw,
+            struct ggml_tensor  * ph);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
@@ -1500,6 +1602,7 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * tensor);
 
+
     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
 
     GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
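
Together these declarations cover the operators needed by the SAM image encoder and the Stable Diffusion blocks. A minimal sketch of how a few of them compose, following the shape example in the ggml_conv_2d_sk_p0 comment above (the context ctx, the kernel type, the n_groups value and the overall flow are illustrative assumptions, not code from either model):

    // kernel: 16 x 16 x 3 x 768, image: 1024 x 1024 x 3 x 1 (see comment above)
    struct ggml_tensor * k  = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,   16,   16, 3, 768);
    struct ggml_tensor * im = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1024, 1024, 3,   1);

    // patch embedding: stride == kernel size, zero padding -> 64 x 64 x 768 x 1
    struct ggml_tensor * cur = ggml_conv_2d_sk_p0(ctx, k, im);

    // group normalization along ne0*ne1*n_groups (eps hardcoded to 1e-6);
    // 32 groups is an assumed value
    cur = ggml_group_norm(ctx, cur, 32);

    // concat along dim 2, per the stable-diffusion usage noted above
    struct ggml_tensor * cat = ggml_concat(ctx, cur, cur);
    (void) cat;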

+ 11 - 9
scripts/sync-ggml.sh

@@ -1,14 +1,16 @@
 #!/bin/bash
 
-cp -rpv ../ggml/src/ggml.c           ./ggml.c
-cp -rpv ../ggml/src/ggml-cuda.h      ./ggml-cuda.h
-cp -rpv ../ggml/src/ggml-cuda.cu     ./ggml-cuda.cu
-cp -rpv ../ggml/src/ggml-opencl.h    ./ggml-opencl.h
-cp -rpv ../ggml/src/ggml-opencl.cpp  ./ggml-opencl.cpp
-cp -rpv ../ggml/src/ggml-metal.h     ./ggml-metal.h
-cp -rpv ../ggml/src/ggml-metal.m     ./ggml-metal.m
-cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
-cp -rpv ../ggml/include/ggml/ggml.h  ./ggml.h
+cp -rpv ../ggml/src/ggml.c                ./ggml.c
+cp -rpv ../ggml/src/ggml-alloc.c          ./ggml-alloc.c
+cp -rpv ../ggml/src/ggml-cuda.h           ./ggml-cuda.h
+cp -rpv ../ggml/src/ggml-cuda.cu          ./ggml-cuda.cu
+cp -rpv ../ggml/src/ggml-opencl.h         ./ggml-opencl.h
+cp -rpv ../ggml/src/ggml-opencl.cpp       ./ggml-opencl.cpp
+cp -rpv ../ggml/src/ggml-metal.h          ./ggml-metal.h
+cp -rpv ../ggml/src/ggml-metal.m          ./ggml-metal.m
+cp -rpv ../ggml/src/ggml-metal.metal      ./ggml-metal.metal
+cp -rpv ../ggml/include/ggml/ggml.h       ./ggml.h
+cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h
 
 cp -rpv ../ggml/tests/test-opt.cpp    ./tests/test-opt.cpp
 cp -rpv ../ggml/tests/test-grad0.cpp  ./tests/test-grad0.cpp

Some files were not shown because too many files have changed in this diff