@@ -524,6 +524,8 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_TANH_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
+#define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
@@ -540,6 +542,7 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
 #define CUDA_PAD_BLOCK_SIZE 256
 #define CUDA_ACC_BLOCK_SIZE 256
 #define CUDA_IM2COL_BLOCK_SIZE 256
+#define CUDA_POOL2D_BLOCK_SIZE 256

 #define CUDA_Q8_0_NE_ALIGN 2048

@@ -823,6 +826,24 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
     dst[i] = fmaxf(x[i], 0);
 }

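+// hardsigmoid(x) = clamp((x + 3) / 6, 0, 1), computed element-wise with one thread per element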
+static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
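+// hardswish(x) = x * hardsigmoid(x) = x * clamp((x + 3) / 6, 0, 1)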
+static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
 static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
@@ -5823,7 +5844,7 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
 }

 static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.y;
+    const int row = blockIdx.x;
     const int col = threadIdx.x;

     float sum = 0.0f;
@@ -6145,9 +6166,10 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }

-static __global__ void im2col_f32_f16(
-        const float * x, half * dst,
-        int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
+template <typename T>
+static __global__ void im2col_kernel(
+        const float * x, T * dst, int batch_offset,
+        int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
     const int i = threadIdx.x + blockIdx.x * blockDim.x;
     if (i >= pelements) {
@@ -6160,21 +6182,73 @@ static __global__ void im2col_f32_f16(
     const int ky = (i - kd) / OW;
     const int ix = i % OW;

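+    // blockIdx.y selects the output row; blockIdx.z packs the (batch, input channel) pair (launched as batch * IC blocks)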
+    const int oh = blockIdx.y;
+    const int batch = blockIdx.z / IC;
+    const int ic = blockIdx.z % IC;
+
     const int64_t iiw = ix * s0 + kx * d0 - p0;
-    const int64_t iih = blockIdx.y * s1 + ky * d1 - p1;
+    const int64_t iih = oh * s1 + ky * d1 - p1;

     const int64_t offset_dst =
-        (blockIdx.y * OW + ix) * CHW +
-        (blockIdx.z * (KW * KH) + ky * KW + kx);
+        ((batch * OH + oh) * OW + ix) * CHW +
+        (ic * (KW * KH) + ky * KW + kx);

     if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-        dst[offset_dst] = __float2half(0.0f);
+        dst[offset_dst] = 0.0f;
     } else {
-        const int64_t offset_src = blockIdx.z * offset_delta;
-        dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
+        const int64_t offset_src = ic * offset_delta + batch * batch_offset;
+        dst[offset_dst] = x[offset_src + iih * IW + iiw];
     }
 }

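+// 2D pooling over an NCHW tensor: one thread per output element; the window is clamped to the input bounds,
+// AVG accumulates with a 1/(kh*kw) scale, MAX tracks the running maximum.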
+template <typename Ti, typename To>
+static __global__ void pool2d_nchw_kernel(
+        const int ih, const int iw, const int oh, const int ow,
+        const int kh, const int kw, const int sh, const int sw,
+        const int ph, const int pw, const int parallel_elements,
+        const Ti * src, To * dst, const enum ggml_op_pool op) {
+    int idx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (idx >= parallel_elements) {
+        return;
+    }
+
+    const int I_HW = ih * iw;
+    const int O_HW = oh * ow;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / ow;
+    const int cur_ow = idx % O_HW % ow;
+    const Ti * i_ptr = src + nc * I_HW;
+    To * o_ptr = dst + nc * O_HW;
+    const int start_h = cur_oh * sh - ph;
+    const int bh = max(0, start_h);
+    const int eh = min(ih, start_h + kh);
+    const int start_w = cur_ow * sw - pw;
+    const int bw = max(0, start_w);
+    const int ew = min(iw, start_w + kw);
+    const To scale = 1. / (kh * kw);
+    To res = 0;
+
+    switch (op) {
+        case GGML_OP_POOL_AVG: res = 0; break;
+        case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
+    }
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+#if __CUDA_ARCH__ >= 350
+            Ti cur = __ldg(i_ptr + i * iw + j);
+#else
+            Ti cur = i_ptr[i * iw + j];
+#endif
+            switch (op) {
+                case GGML_OP_POOL_AVG: res += cur * scale; break;
+                case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
+            }
+        }
+    }
+    o_ptr[cur_oh * ow + cur_ow] = res;
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                           const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
@@ -6388,6 +6462,16 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }

+static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
+    hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
+    hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
     leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
@@ -7475,7 +7559,7 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const

 static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums(1, nrows, 1);
+    const dim3 block_nums(nrows, 1, 1);
     k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
 }

@@ -7587,14 +7671,15 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
     }
 }

-static void im2col_f32_f16_cuda(const float* x, half* dst,
+template <typename T>
+static void im2col_cuda(const float* x, T* dst,
     int IW, int IH, int OW, int OH, int KW, int KH, int IC,
-    int offset_delta,
+    int batch, int batch_offset, int offset_delta,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
     const int parallel_elements = OW * KW * KH;
     const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
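+    // launch grid: x tiles the OW*KW*KH elements, y indexes the output row, z indexes one (batch, channel) pair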
+    dim3 block_nums(num_blocks, OH, batch * IC);
+    im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
 }

 // buffer pool for cuda
@@ -8179,6 +8264,34 @@ static void ggml_cuda_op_relu(
     (void) src1_dd;
 }

+static void ggml_cuda_op_hardsigmoid(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    hardsigmoid_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+static void ggml_cuda_op_hardswish(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    hardswish_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 static void ggml_cuda_op_leaky_relu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
@@ -8810,13 +8923,46 @@ static void ggml_cuda_op_alibi(
     (void) src1_dd;
 }

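+// dst->op_params, as read below: [op, k0, k1, s0, s1, p0, p1]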
+static void ggml_cuda_op_pool2d(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+
+    const int64_t IH = src0->ne[1];
+    const int64_t IW = src0->ne[0];
+
+    const int64_t N = dst->ne[3];
+    const int64_t OC = dst->ne[2];
+    const int64_t OH = dst->ne[1];
+    const int64_t OW = dst->ne[0];
+
+    const int parallel_elements = N * OC * OH * OW;
+    const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
+    dim3 block_nums(num_blocks);
+    pool2d_nchw_kernel<<<block_nums, CUDA_POOL2D_BLOCK_SIZE, 0, main_stream>>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_dd, dst_dd, op);
+
+    (void) src1;
+    (void) src1_dd;
+}
+
 static void ggml_cuda_op_im2col(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {

     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);

     const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
     const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
@@ -8838,8 +8984,14 @@ static void ggml_cuda_op_im2col(
     const int64_t OW = dst->ne[1];

     const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t batch = src1->ne[3];
+    const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32

-    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    if(dst->type == GGML_TYPE_F16) {
+        im2col_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    } else {
+        im2col_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    }

     (void) src0;
     (void) src0_dd;
@@ -9435,6 +9587,13 @@ static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
 }

+static void ggml_cuda_hardsigmoid(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardsigmoid);
+}
+
+static void ggml_cuda_hardswish(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardswish);
+}
 static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu);
 }
@@ -10220,6 +10379,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }

+static void ggml_cuda_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pool2d);
+}
+
 static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
 }
@@ -10321,6 +10484,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
                 case GGML_UNARY_OP_RELU:
                     func = ggml_cuda_relu;
                     break;
+                case GGML_UNARY_OP_HARDSIGMOID:
+                    func = ggml_cuda_hardsigmoid;
+                    break;
+                case GGML_UNARY_OP_HARDSWISH:
+                    func = ggml_cuda_hardswish;
+                    break;
                 default:
                     return false;
             }
@@ -10395,6 +10564,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
         case GGML_OP_IM2COL:
            func = ggml_cuda_im2col;
            break;
+        case GGML_OP_POOL_2D:
+           func = ggml_cuda_pool2d;
+           break;
        case GGML_OP_SUM_ROWS:
            func = ggml_cuda_sum_rows;
            break;
@@ -11123,6 +11295,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_TANH:
                     return true;
@@ -11221,6 +11395,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ROPE:
         case GGML_OP_ALIBI:
         case GGML_OP_IM2COL:
+        case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC: