|
@@ -25,9 +25,9 @@ typedef struct {
|
|
|
} block_q8_0;
|
|
} block_q8_0;
|
|
|
|
|
|
|
|
kernel void kernel_add(
|
|
kernel void kernel_add(
|
|
|
- device const float * src0,
|
|
|
|
|
- device const float * src1,
|
|
|
|
|
- device float * dst,
|
|
|
|
|
|
|
+ device const float4 * src0,
|
|
|
|
|
+ device const float4 * src1,
|
|
|
|
|
+ device float4 * dst,
|
|
|
uint tpig[[thread_position_in_grid]]) {
|
|
uint tpig[[thread_position_in_grid]]) {
|
|
|
dst[tpig] = src0[tpig] + src1[tpig];
|
|
dst[tpig] = src0[tpig] + src1[tpig];
|
|
|
}
|
|
}
|
|
@@ -35,18 +35,18 @@ kernel void kernel_add(
|
|
|
// assumption: src1 is a row
|
|
// assumption: src1 is a row
|
|
|
// broadcast src1 into src0
|
|
// broadcast src1 into src0
|
|
|
kernel void kernel_add_row(
|
|
kernel void kernel_add_row(
|
|
|
- device const float * src0,
|
|
|
|
|
- device const float * src1,
|
|
|
|
|
- device float * dst,
|
|
|
|
|
- constant int64_t & ne00,
|
|
|
|
|
|
|
+ device const float4 * src0,
|
|
|
|
|
+ device const float4 * src1,
|
|
|
|
|
+ device float4 * dst,
|
|
|
|
|
+ constant int64_t & nb,
|
|
|
uint tpig[[thread_position_in_grid]]) {
|
|
uint tpig[[thread_position_in_grid]]) {
|
|
|
- dst[tpig] = src0[tpig] + src1[tpig % ne00];
|
|
|
|
|
|
|
+ dst[tpig] = src0[tpig] + src1[tpig % nb];
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
kernel void kernel_mul(
|
|
kernel void kernel_mul(
|
|
|
- device const float * src0,
|
|
|
|
|
- device const float * src1,
|
|
|
|
|
- device float * dst,
|
|
|
|
|
|
|
+ device const float4 * src0,
|
|
|
|
|
+ device const float4 * src1,
|
|
|
|
|
+ device float4 * dst,
|
|
|
uint tpig[[thread_position_in_grid]]) {
|
|
uint tpig[[thread_position_in_grid]]) {
|
|
|
dst[tpig] = src0[tpig] * src1[tpig];
|
|
dst[tpig] = src0[tpig] * src1[tpig];
|
|
|
}
|
|
}
|
|
@@ -54,12 +54,12 @@ kernel void kernel_mul(
|
|
|
// assumption: src1 is a row
|
|
// assumption: src1 is a row
|
|
|
// broadcast src1 into src0
|
|
// broadcast src1 into src0
|
|
|
kernel void kernel_mul_row(
|
|
kernel void kernel_mul_row(
|
|
|
- device const float * src0,
|
|
|
|
|
- device const float * src1,
|
|
|
|
|
- device float * dst,
|
|
|
|
|
- constant int64_t & ne00,
|
|
|
|
|
|
|
+ device const float4 * src0,
|
|
|
|
|
+ device const float4 * src1,
|
|
|
|
|
+ device float4 * dst,
|
|
|
|
|
+ constant int64_t & nb,
|
|
|
uint tpig[[thread_position_in_grid]]) {
|
|
uint tpig[[thread_position_in_grid]]) {
|
|
|
- dst[tpig] = src0[tpig] * src1[tpig % ne00];
|
|
|
|
|
|
|
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
kernel void kernel_scale(
|
|
kernel void kernel_scale(
|