|
@@ -2520,8 +2520,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
|
|
case GGML_OP_CLAMP:
|
|
case GGML_OP_CLAMP:
|
|
|
return op->src[0]->type == GGML_TYPE_F32;
|
|
return op->src[0]->type == GGML_TYPE_F32;
|
|
|
case GGML_OP_SOFT_MAX:
|
|
case GGML_OP_SOFT_MAX:
|
|
|
- // TODO: support attention sinks [TAG_ATTN_SINKS]
|
|
|
|
|
- return op->src[2] == nullptr;
|
|
|
|
|
case GGML_OP_NORM:
|
|
case GGML_OP_NORM:
|
|
|
case GGML_OP_RMS_NORM:
|
|
case GGML_OP_RMS_NORM:
|
|
|
return true;
|
|
return true;
|
|
@@ -6594,17 +6592,24 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
|
|
|
GGML_ASSERT(src1->extra);
|
|
GGML_ASSERT(src1->extra);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ const ggml_tensor * src2 = dst->src[2];
|
|
|
|
|
+ if (src2) {
|
|
|
|
|
+ GGML_ASSERT(src2->extra);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
|
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
|
|
|
|
|
|
|
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
|
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
|
|
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
|
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
|
|
|
|
|
|
|
ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
|
|
ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
|
|
|
|
|
+ ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;
|
|
|
|
|
|
|
|
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
|
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
|
|
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
|
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
|
|
|
|
|
|
|
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
|
|
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
|
|
|
|
|
+ cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;
|
|
|
|
|
|
|
|
const int ne00 = src0->ne[0];
|
|
const int ne00 = src0->ne[0];
|
|
|
const int ne01 = src0->ne[1];
|
|
const int ne01 = src0->ne[1];
|
|
@@ -6672,25 +6677,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
|
|
|
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
|
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
|
|
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? &extra1->data_device : &extra0->data_device));
|
|
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? &extra1->data_device : &extra0->data_device));
|
|
|
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
|
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1));
|
|
|
|
|
- CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2));
|
|
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), extra2 ? &extra2->data_device : &extra0->data_device));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &scale));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &max_bias));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float), &m0));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &m1));
|
|
|
|
|
+ CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &n_head_log2));
|
|
|
|
|
|
|
|
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
|
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
|
|
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
|
size_t local_work_size[] = {(size_t)nth, 1, 1};
|