Jelajahi Sumber

opencl: support sink in `soft_max` (attn sinks) (#15152)

lhez 5 bulan lalu
induk
melakukan
aaa3d07ae7

+ 28 - 21
ggml/src/ggml-opencl/ggml-opencl.cpp

@@ -2520,8 +2520,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         case GGML_OP_CLAMP:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SOFT_MAX:
-            // TODO: support attention sinks [TAG_ATTN_SINKS]
-            return op->src[2] == nullptr;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
             return true;
@@ -6594,17 +6592,24 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
         GGML_ASSERT(src1->extra);
     }
 
+    const ggml_tensor * src2 = dst->src[2];
+    if (src2) {
+        GGML_ASSERT(src2->extra);
+    }
+
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
 
     ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
+    ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;
 
     cl_ulong offset0 = extra0->offset + src0->view_offs;
     cl_ulong offsetd = extrad->offset + dst->view_offs;
 
     cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
+    cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;
 
     const int ne00 = src0->ne[0];
     const int ne01 = src0->ne[1];
@@ -6672,25 +6677,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
     CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
     CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   extra1 ? &extra1->data_device : &extra0->data_device));
     CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
-    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
-    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
-    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
-    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));
-    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));
-    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));
-    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
-    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne13));
-    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
-    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
-    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
-    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
-    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
-    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
-    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float),    &scale));
-    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float),    &max_bias));
-    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float),    &m0));
-    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float),    &m1));
-    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &n_head_log2));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   extra2 ? &extra2->data_device : &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float),    &scale));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float),    &max_bias));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float),    &m0));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),    &m1));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &n_head_log2));
 
     size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
     size_t local_work_size[] = {(size_t)nth, 1, 1};

+ 10 - 2
ggml/src/ggml-opencl/kernels/softmax_4_f16.cl

@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4_f16(
         ulong offset0,
         global char * src1,
         ulong offset1,
+        global char * src2,
+        ulong offset2,
         global char * dst,
         ulong offsetd,
         int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4_f16(
 ) {
     src0 = src0 + offset0;
     src1 = src1 + offset1;
+    src2 = src2 + offset2;
     dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4_f16(
 
     global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
     global half4  * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float  * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
     global float4 * pdst4 = (global float4 *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4_f16(
     }
 
     // parallel max
-    float4 lmax4 = -INFINITY;
+    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
     for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
         lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
     }
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4_f16(
     }
     float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
 
-    const float sum = sub_group_reduce_add(lsum);
+    float sum = sub_group_reduce_add(lsum);
+
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max);
+    }
 
     for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
         pdst4[i00] /= sum;

+ 10 - 2
ggml/src/ggml-opencl/kernels/softmax_4_f32.cl

@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4(
         ulong offset0,
         global char * src1,
         ulong offset1,
+        global char * src2,
+        ulong offset2,
         global char * dst,
         ulong offsetd,
         int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4(
 ) {
     src0 = src0 + offset0;
     src1 = src1 + offset1;
+    src2 = src2 + offset2;
     dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4(
 
     global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
     global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float  * psrc2 = src2 != src0 ? (global float  *)(src2) : 0;
     global float4 * pdst4 = (global float4 *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4(
     }
 
     // parallel max
-    float4 lmax4 = -INFINITY;
+    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
     for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
         lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
     }
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4(
     }
     float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
 
-    const float sum = sub_group_reduce_add(lsum);
+    float sum = sub_group_reduce_add(lsum);
+
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max);
+    }
 
     for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
         pdst4[i00] /= sum;

+ 10 - 2
ggml/src/ggml-opencl/kernels/softmax_f16.cl

@@ -26,6 +26,8 @@ kernel void kernel_soft_max_f16(
         ulong offset0,
         global char * src1,
         ulong offset1,
+        global char * src2,
+        ulong offset2,
         global char * dst,
         ulong offsetd,
         int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_f16(
 ) {
     src0 = src0 + offset0;
     src1 = src1 + offset1;
+    src2 = src2 + offset2;
     dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_f16(
 
     global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
     global half  * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
     global float * pdst  = (global float *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_f16(
     }
 
     // parallel max
-    float lmax = -INFINITY;
+    float lmax = psrc2 ? psrc2[i02] : -INFINITY;
     for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
         lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
     }
@@ -91,7 +95,11 @@ kernel void kernel_soft_max_f16(
         pdst[i00] = exp_psrc0;
     }
 
-    const float sum = sub_group_reduce_add(lsum);
+    float sum = sub_group_reduce_add(lsum);
+
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max);
+    }
 
     for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
         pdst[i00] /= sum;

+ 10 - 2
ggml/src/ggml-opencl/kernels/softmax_f32.cl

@@ -26,6 +26,8 @@ kernel void kernel_soft_max(
         ulong offset0,
         global char * src1,
         ulong offset1,
+        global char * src2,
+        ulong offset2,
         global char * dst,
         ulong offsetd,
         int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max(
 ) {
     src0 = src0 + offset0;
     src1 = src1 + offset1;
+    src2 = src2 + offset2;
     dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max(
 
     global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
     global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
     global float * pdst  = (global float *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max(
     }
 
     // parallel max
-    float lmax = -INFINITY;
+    float lmax = psrc2 ? psrc2[i02] : -INFINITY;
     for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
         lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
     }
@@ -91,7 +95,11 @@ kernel void kernel_soft_max(
         pdst[i00] = exp_psrc0;
     }
 
-    const float sum = sub_group_reduce_add(lsum);
+    float sum = sub_group_reduce_add(lsum);
+
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max);
+    }
 
     for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
         pdst[i00] /= sum;