|
@@ -9,13 +9,13 @@ using namespace ggml_cuda_mma;
|
|
|
|
|
|
|
|
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
|
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
|
|
|
|
|
|
|
-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
|
|
|
|
|
|
|
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
|
|
|
|
|
|
|
|
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
|
|
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
|
|
|
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
|
|
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
|
|
|
static __global__ void mul_mat_f(
|
|
static __global__ void mul_mat_f(
|
|
|
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
|
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
|
|
- const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
|
|
|
|
|
|
|
+ const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
|
|
|
const int stride_col_id, const int stride_row_id,
|
|
const int stride_col_id, const int stride_row_id,
|
|
|
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
|
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
|
|
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
|
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
|
@@ -31,9 +31,20 @@ static __global__ void mul_mat_f(
|
|
|
|
|
|
|
|
const int row0 = blockIdx.x * rows_per_block;
|
|
const int row0 = blockIdx.x * rows_per_block;
|
|
|
|
|
|
|
|
- const int expert_idx = has_ids ? blockIdx.y : 0;
|
|
|
|
|
|
|
+ int expert_idx = 0;
|
|
|
|
|
+ int col_base = 0;
|
|
|
|
|
+
|
|
|
const int channel_dst = has_ids ? 0 : blockIdx.y;
|
|
const int channel_dst = has_ids ? 0 : blockIdx.y;
|
|
|
|
|
|
|
|
|
|
+ if constexpr (has_ids) {
|
|
|
|
|
+ // experts + tiles of ncols_dst are packed in the y dimension
|
|
|
|
|
+ int col_tiles = (ncols_dst_total + cols_per_block - 1) / cols_per_block;
|
|
|
|
|
+ const int nchannels_x = gridDim.y / col_tiles;
|
|
|
|
|
+ const int tile_idx = blockIdx.y / nchannels_x;
|
|
|
|
|
+ expert_idx = blockIdx.y - tile_idx * nchannels_x;
|
|
|
|
|
+ col_base = tile_idx * cols_per_block;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
|
|
const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
|
|
|
const int channel_y = channel_dst;
|
|
const int channel_y = channel_dst;
|
|
|
const int sample_dst = blockIdx.z;
|
|
const int sample_dst = blockIdx.z;
|
|
@@ -44,6 +55,14 @@ static __global__ void mul_mat_f(
|
|
|
y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y);
|
|
y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y);
|
|
|
dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
|
|
dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
|
|
|
|
|
|
|
|
|
|
+ if constexpr (has_ids) {
|
|
|
|
|
+ constexpr int y_stride_scale = std::is_same_v<T, float> ? 1 : 2;
|
|
|
|
|
+ const int64_t col_offset = col_base;
|
|
|
|
|
+ y += col_offset * stride_col_y * y_stride_scale;
|
|
|
|
|
+ dst += col_offset * stride_col_dst;
|
|
|
|
|
+ ids += col_offset * stride_row_id;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
const float2 * y2 = (const float2 *) y;
|
|
const float2 * y2 = (const float2 *) y;
|
|
|
|
|
|
|
|
extern __shared__ char data_mmv[];
|
|
extern __shared__ char data_mmv[];
|
|
@@ -61,12 +80,17 @@ static __global__ void mul_mat_f(
|
|
|
|
|
|
|
|
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
|
|
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
|
|
|
const int j = j0 + threadIdx.y;
|
|
const int j = j0 + threadIdx.y;
|
|
|
- const int32_t * __restrict__ id_row = ids + j*stride_row_id;
|
|
|
|
|
|
|
|
|
|
if (threadIdx.x == 0) {
|
|
if (threadIdx.x == 0) {
|
|
|
slot_map[j] = -1;
|
|
slot_map[j] = -1;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ if (col_base + j >= ncols_dst_total) {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ const int32_t * __restrict__ id_row = ids + j*stride_row_id;
|
|
|
|
|
+
|
|
|
for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
|
|
for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
|
|
|
int match = id_row[k*stride_col_id] == expert_idx;
|
|
int match = id_row[k*stride_col_id] == expert_idx;
|
|
|
|
|
|
|
@@ -108,7 +132,8 @@ static __global__ void mul_mat_f(
|
|
|
if constexpr (!has_ids) {
|
|
if constexpr (!has_ids) {
|
|
|
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
|
|
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
|
|
|
} else {
|
|
} else {
|
|
|
- tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
|
|
|
|
|
|
|
+ const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
|
|
|
|
|
+ tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
|
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
|
@@ -120,7 +145,8 @@ static __global__ void mul_mat_f(
|
|
|
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
|
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
|
|
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
|
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
|
|
} else {
|
|
} else {
|
|
|
- float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
|
|
|
|
|
|
|
+ const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
|
|
|
|
|
+ float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
|
|
|
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
|
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -183,14 +209,14 @@ static __global__ void mul_mat_f(
|
|
|
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
|
|
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
|
|
|
} else {
|
|
} else {
|
|
|
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
|
|
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
|
|
|
- if (slot >= 0) {
|
|
|
|
|
|
|
+ if (slot >= 0 && (col_base + j) < ncols_dst_total) {
|
|
|
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
|
|
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
#else
|
|
#else
|
|
|
GGML_UNUSED_VARS(x, y, ids, dst,
|
|
GGML_UNUSED_VARS(x, y, ids, dst,
|
|
|
- ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id,
|
|
stride_col_id, stride_row_id,
|
|
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
@@ -201,20 +227,23 @@ static __global__ void mul_mat_f(
|
|
|
template<typename T, int cols_per_block, int nwarps>
|
|
template<typename T, int cols_per_block, int nwarps>
|
|
|
static inline void mul_mat_f_switch_ids(
|
|
static inline void mul_mat_f_switch_ids(
|
|
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
|
|
- const int64_t ncols_x, const int64_t nchannels_dst,
|
|
|
|
|
|
|
+ const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
|
|
|
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
|
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
|
|
const int64_t stride_col_id, const int64_t stride_row_id,
|
|
const int64_t stride_col_id, const int64_t stride_row_id,
|
|
|
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
|
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
|
|
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
|
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
|
|
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
|
|
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
|
|
|
if (ids) {
|
|
if (ids) {
|
|
|
- mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
|
|
|
|
- (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
|
|
|
|
|
+ dim3 block_nums_ids = block_nums;
|
|
|
|
|
+ block_nums_ids.y *= col_tiles;
|
|
|
|
|
+ mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
|
|
|
|
+ (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
|
} else {
|
|
} else {
|
|
|
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
|
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
|
|
- (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
|
}
|
|
}
|
|
@@ -223,7 +252,8 @@ static inline void mul_mat_f_switch_ids(
|
|
|
template <typename T, int cols_per_block>
|
|
template <typename T, int cols_per_block>
|
|
|
void mul_mat_f_cuda(
|
|
void mul_mat_f_cuda(
|
|
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
|
|
- const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
|
|
|
|
|
|
+ const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
|
|
|
|
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
|
|
const int64_t stride_col_id, const int64_t stride_row_id,
|
|
const int64_t stride_col_id, const int64_t stride_row_id,
|
|
|
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
|
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
|
|
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
|
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
|
@@ -268,49 +298,49 @@ void mul_mat_f_cuda(
|
|
|
switch (nwarps_best) {
|
|
switch (nwarps_best) {
|
|
|
case 1: {
|
|
case 1: {
|
|
|
mul_mat_f_switch_ids<T, cols_per_block, 1>(
|
|
mul_mat_f_switch_ids<T, cols_per_block, 1>(
|
|
|
- x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 2: {
|
|
case 2: {
|
|
|
mul_mat_f_switch_ids<T, cols_per_block, 2>(
|
|
mul_mat_f_switch_ids<T, cols_per_block, 2>(
|
|
|
- x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 3: {
|
|
case 3: {
|
|
|
mul_mat_f_switch_ids<T, cols_per_block, 3>(
|
|
mul_mat_f_switch_ids<T, cols_per_block, 3>(
|
|
|
- x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 4: {
|
|
case 4: {
|
|
|
mul_mat_f_switch_ids<T, cols_per_block, 4>(
|
|
mul_mat_f_switch_ids<T, cols_per_block, 4>(
|
|
|
- x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 5: {
|
|
case 5: {
|
|
|
mul_mat_f_switch_ids<T, cols_per_block, 5>(
|
|
mul_mat_f_switch_ids<T, cols_per_block, 5>(
|
|
|
- x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 6: {
|
|
case 6: {
|
|
|
mul_mat_f_switch_ids<T, cols_per_block, 6>(
|
|
mul_mat_f_switch_ids<T, cols_per_block, 6>(
|
|
|
- x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 7: {
|
|
case 7: {
|
|
|
mul_mat_f_switch_ids<T, cols_per_block, 7>(
|
|
mul_mat_f_switch_ids<T, cols_per_block, 7>(
|
|
|
- x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 8: {
|
|
case 8: {
|
|
|
mul_mat_f_switch_ids<T, cols_per_block, 8>(
|
|
mul_mat_f_switch_ids<T, cols_per_block, 8>(
|
|
|
- x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
|
|
} break;
|
|
} break;
|
|
@@ -332,84 +362,89 @@ static void mul_mat_f_switch_cols_per_block(
|
|
|
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
|
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
|
|
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
|
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
|
|
cudaStream_t stream) {
|
|
cudaStream_t stream) {
|
|
|
- switch (ncols_dst) {
|
|
|
|
|
|
|
+
|
|
|
|
|
+ const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
|
|
|
|
|
+
|
|
|
|
|
+ GGML_ASSERT(ids || ncols_dst <= 16);
|
|
|
|
|
+
|
|
|
|
|
+ switch (ncols_case) {
|
|
|
case 1: {
|
|
case 1: {
|
|
|
- mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 2: {
|
|
case 2: {
|
|
|
- mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 3: {
|
|
case 3: {
|
|
|
- mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 4: {
|
|
case 4: {
|
|
|
- mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 5: {
|
|
case 5: {
|
|
|
- mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 6: {
|
|
case 6: {
|
|
|
- mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 7: {
|
|
case 7: {
|
|
|
- mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 8: {
|
|
case 8: {
|
|
|
- mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 9: {
|
|
case 9: {
|
|
|
- mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 10: {
|
|
case 10: {
|
|
|
- mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 11: {
|
|
case 11: {
|
|
|
- mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 12: {
|
|
case 12: {
|
|
|
- mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 13: {
|
|
case 13: {
|
|
|
- mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 14: {
|
|
case 14: {
|
|
|
- mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 15: {
|
|
case 15: {
|
|
|
- mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
|
case 16: {
|
|
case 16: {
|
|
|
- mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
|
|
|
|
|
+ mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
} break;
|
|
} break;
|
|
@@ -422,7 +457,7 @@ static void mul_mat_f_switch_cols_per_block(
|
|
|
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
|
|
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
|
|
|
template void mul_mat_f_cuda<T, ncols_dst>( \
|
|
template void mul_mat_f_cuda<T, ncols_dst>( \
|
|
|
const T * x, const float * y, const int32_t * ids, float * dst, \
|
|
const T * x, const float * y, const int32_t * ids, float * dst, \
|
|
|
- const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
|
|
|
|
|
|
|
+ const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
|
|
|
const int64_t stride_col_id, const int64_t stride_row_id, \
|
|
const int64_t stride_col_id, const int64_t stride_row_id, \
|
|
|
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
|
|
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
|
|
|
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
|
|
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
|