|
|
@@ -39,6 +39,8 @@ typedef struct {
|
|
|
int8_t qs[QK8_0]; // quants
|
|
|
} block_q8_0;
|
|
|
|
|
|
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
|
|
+
|
|
|
// general-purpose kernel for addition of two tensors
|
|
|
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
|
|
|
// cons: not very efficient
|
|
|
@@ -180,10 +182,12 @@ kernel void kernel_gelu(
|
|
|
|
|
|
kernel void kernel_soft_max(
|
|
|
device const float * src0,
|
|
|
+ device const float * src1,
|
|
|
device float * dst,
|
|
|
constant int64_t & ne00,
|
|
|
constant int64_t & ne01,
|
|
|
constant int64_t & ne02,
|
|
|
+ constant float & scale,
|
|
|
threadgroup float * buf [[threadgroup(0)]],
|
|
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
|
@@ -194,73 +198,77 @@ kernel void kernel_soft_max(
|
|
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
|
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
|
|
|
|
|
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
|
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
|
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
|
+ device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
|
|
|
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
|
|
|
|
// parallel max
|
|
|
- float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
|
|
|
+ float lmax = -INFINITY;
|
|
|
|
|
|
- for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
|
|
|
- lmax = MAX(lmax, psrc0[i00]);
|
|
|
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
|
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
|
|
}
|
|
|
|
|
|
- float max = simd_max(lmax);
|
|
|
- if (tiisg == 0) {
|
|
|
- buf[sgitg] = max;
|
|
|
- }
|
|
|
+ // find the max value in the block
|
|
|
+ float max_val = simd_max(lmax);
|
|
|
+ if (ntg > N_SIMDWIDTH) {
|
|
|
+ if (sgitg == 0) {
|
|
|
+ buf[tiisg] = -INFINITY;
|
|
|
+ }
|
|
|
|
|
|
- threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
- // broadcast, simd group number is ntg / 32
|
|
|
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
|
|
- if (tpitg < i) {
|
|
|
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
|
|
- }
|
|
|
- }
|
|
|
+ if (tiisg == 0) {
|
|
|
+ buf[sgitg] = max_val;
|
|
|
+ }
|
|
|
|
|
|
- threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
- max = buf[0];
|
|
|
+ max_val = buf[tiisg];
|
|
|
+ max_val = simd_max(max_val);
|
|
|
+ }
|
|
|
|
|
|
// parallel sum
|
|
|
float lsum = 0.0f;
|
|
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
|
- const float exp_psrc0 = exp(psrc0[i00] - max);
|
|
|
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
|
|
lsum += exp_psrc0;
|
|
|
- // Remember the result of exp here. exp is expensive, so we really do not
|
|
|
- // wish to compute it twice.
|
|
|
pdst[i00] = exp_psrc0;
|
|
|
}
|
|
|
|
|
|
float sum = simd_sum(lsum);
|
|
|
- if (tiisg == 0) {
|
|
|
- buf[sgitg] = sum;
|
|
|
- }
|
|
|
+ if (ntg > N_SIMDWIDTH) {
|
|
|
+ if (sgitg == 0) {
|
|
|
+ buf[tiisg] = 0.0f;
|
|
|
+ }
|
|
|
|
|
|
- threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
- // broadcast, simd group number is ntg / 32
|
|
|
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
|
|
- if (tpitg < i) {
|
|
|
- buf[tpitg] += buf[tpitg + i];
|
|
|
- }
|
|
|
- }
|
|
|
+ if (tiisg == 0) {
|
|
|
+ buf[sgitg] = sum;
|
|
|
+ }
|
|
|
|
|
|
- threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
+
|
|
|
+ sum = buf[tiisg];
|
|
|
+ sum = simd_sum(sum);
|
|
|
+ }
|
|
|
|
|
|
- sum = buf[0];
|
|
|
+ const float inv_sum = 1.0f/sum;
|
|
|
|
|
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
|
- pdst[i00] /= sum;
|
|
|
+ pdst[i00] *= inv_sum;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
kernel void kernel_soft_max_4(
|
|
|
device const float * src0,
|
|
|
+ device const float * src1,
|
|
|
device float * dst,
|
|
|
constant int64_t & ne00,
|
|
|
constant int64_t & ne01,
|
|
|
constant int64_t & ne02,
|
|
|
+ constant float & scale,
|
|
|
threadgroup float * buf [[threadgroup(0)]],
|
|
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
|
@@ -271,64 +279,68 @@ kernel void kernel_soft_max_4(
|
|
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
|
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
|
|
|
|
|
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
|
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
|
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
|
+ device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
|
|
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
|
|
|
|
// parallel max
|
|
|
- float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
|
|
|
+ float4 lmax4 = -INFINITY;
|
|
|
|
|
|
- for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
|
|
|
- lmax4 = fmax(lmax4, psrc4[i00]);
|
|
|
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
|
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
|
|
}
|
|
|
|
|
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
|
- float max = simd_max(lmax);
|
|
|
- if (tiisg == 0) {
|
|
|
- buf[sgitg] = max;
|
|
|
- }
|
|
|
|
|
|
- threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
+ float max_val = simd_max(lmax);
|
|
|
+ if (ntg > N_SIMDWIDTH) {
|
|
|
+ if (sgitg == 0) {
|
|
|
+ buf[tiisg] = -INFINITY;
|
|
|
+ }
|
|
|
|
|
|
- // broadcast, simd group number is ntg / 32
|
|
|
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
|
|
- if (tpitg < i) {
|
|
|
- buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
|
|
- }
|
|
|
- }
|
|
|
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
- threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
+ if (tiisg == 0) {
|
|
|
+ buf[sgitg] = max_val;
|
|
|
+ }
|
|
|
+
|
|
|
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
- max = buf[0];
|
|
|
+ max_val = buf[tiisg];
|
|
|
+ max_val = simd_max(max_val);
|
|
|
+ }
|
|
|
|
|
|
// parallel sum
|
|
|
float4 lsum4 = 0.0f;
|
|
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
|
- const float4 exp_psrc4 = exp(psrc4[i00] - max);
|
|
|
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
|
|
lsum4 += exp_psrc4;
|
|
|
pdst4[i00] = exp_psrc4;
|
|
|
}
|
|
|
|
|
|
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
|
|
float sum = simd_sum(lsum);
|
|
|
- if (tiisg == 0) {
|
|
|
- buf[sgitg] = sum;
|
|
|
- }
|
|
|
+ if (ntg > N_SIMDWIDTH) {
|
|
|
+ if (sgitg == 0) {
|
|
|
+ buf[tiisg] = 0.0f;
|
|
|
+ }
|
|
|
|
|
|
- threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
- // broadcast, simd group number is ntg / 32
|
|
|
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
|
|
- if (tpitg < i) {
|
|
|
- buf[tpitg] += buf[tpitg + i];
|
|
|
- }
|
|
|
- }
|
|
|
+ if (tiisg == 0) {
|
|
|
+ buf[sgitg] = sum;
|
|
|
+ }
|
|
|
|
|
|
- threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
+
|
|
|
+ sum = buf[tiisg];
|
|
|
+ sum = simd_sum(sum);
|
|
|
+ }
|
|
|
|
|
|
- sum = buf[0];
|
|
|
+ const float inv_sum = 1.0f/sum;
|
|
|
|
|
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
|
- pdst4[i00] /= sum;
|
|
|
+ pdst4[i00] *= inv_sum;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -435,14 +447,13 @@ kernel void kernel_rms_norm(
|
|
|
constant int64_t & ne00,
|
|
|
constant uint64_t & nb01,
|
|
|
constant float & eps,
|
|
|
- threadgroup float * sum [[threadgroup(0)]],
|
|
|
+ threadgroup float * buf [[threadgroup(0)]],
|
|
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
|
uint ntg[[threads_per_threadgroup]]) {
|
|
|
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
|
|
- device const float * x_scalar = (device const float *) x;
|
|
|
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
|
|
|
|
|
float4 sumf = 0;
|
|
|
float all_sum = 0;
|
|
|
@@ -453,40 +464,30 @@ kernel void kernel_rms_norm(
|
|
|
}
|
|
|
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
|
|
|
all_sum = simd_sum(all_sum);
|
|
|
- if (tiisg == 0) {
|
|
|
- sum[sgitg] = all_sum;
|
|
|
- }
|
|
|
+ if (ntg > N_SIMDWIDTH) {
|
|
|
+ if (sgitg == 0) {
|
|
|
+ buf[tiisg] = 0.0f;
|
|
|
+ }
|
|
|
|
|
|
- threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
- // broadcast, simd group number is ntg / 32
|
|
|
- for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
|
|
- if (tpitg < i) {
|
|
|
- sum[tpitg] += sum[tpitg + i];
|
|
|
- }
|
|
|
- }
|
|
|
- if (tpitg == 0) {
|
|
|
- for (int i = 4 * (ne00 / 4); i < ne00; i++) {
|
|
|
- sum[0] += x_scalar[i];
|
|
|
+ if (tiisg == 0) {
|
|
|
+ buf[sgitg] = all_sum;
|
|
|
}
|
|
|
- sum[0] /= ne00;
|
|
|
- }
|
|
|
|
|
|
- threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
- const float mean = sum[0];
|
|
|
+ all_sum = buf[tiisg];
|
|
|
+ all_sum = simd_sum(all_sum);
|
|
|
+ }
|
|
|
+
|
|
|
+ const float mean = all_sum/ne00;
|
|
|
const float scale = 1.0f/sqrt(mean + eps);
|
|
|
|
|
|
device float4 * y = (device float4 *) (dst + tgpig*ne00);
|
|
|
- device float * y_scalar = (device float *) y;
|
|
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
|
y[i00] = x[i00] * scale;
|
|
|
}
|
|
|
- if (tpitg == 0) {
|
|
|
- for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
|
|
|
- y_scalar[i00] = x_scalar[i00] * scale;
|
|
|
- }
|
|
|
- }
|
|
|
}
|
|
|
|
|
|
// function to calculate the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
|
|
@@ -576,7 +577,6 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
|
// putting them in the kernel cause a significant performance penalty
|
|
|
#define N_DST 4 // each SIMD group works on 4 rows
|
|
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
|
|
-#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
|
|
//Note: This is a template, but strictly speaking it only applies to
|
|
|
// quantizations where the block size is 32. It also does not
|
|
|
// guard against the number of rows not being divisible by
|