@@ -7,6 +7,10 @@

typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

+const int CUDA_CPY_TILE_DIM_2D = 32;  // 2D tile dimension for transposed blocks
+const int CUDA_CPY_BLOCK_NM = 8;      // number of matrices (along the 3rd dimension) handled per thread block
+const int CUDA_CPY_BLOCK_ROWS = 8;    // rows copied per pass while a block marches through a tile
+
template <cpy_kernel_t cpy_1>
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
                               const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -35,6 +39,55 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
    cpy_1(cx + x_offset, cdst + dst_offset);
}

+template <typename T>
+static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
+                                         const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                         const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                         const int nb12, const int nb13) {
+
+    const T* src = reinterpret_cast<const T*>(cx);
+    T* dst = reinterpret_cast<T*>(cdst);
+
+    const int64_t nmat = ne / (ne00 * ne01);
+    const int64_t n = ne00 * ne01;
+
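+    // (x, y) locates this thread's element in the source matrix; (tx, ty) is the matching
+    // location in the transposed destination (block indices swapped between x and y)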
+    const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
+    const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
+    const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
+    const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
+
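+    // the tile is float-typed but accessed through reinterpret_cast<T>, with indices scaled by
+    // sizeof(float)/sizeof(T); the extra padding column keeps transposed column accesses
+    // from landing in the same shared-memory bank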
+    __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
+
+#pragma unroll
+    for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
+
+        const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
+        if (imat >= nmat)
+            break;
+
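+        // load a CUDA_CPY_TILE_DIM_2D x CUDA_CPY_TILE_DIM_2D tile of the source matrix into
+        // shared memory, CUDA_CPY_BLOCK_ROWS rows per pass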
+#pragma unroll
+        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
+            if (x < ne01 && y + j < ne00) {
+                const int row = threadIdx.y+j;
+                const int col = threadIdx.x * sizeof(float)/sizeof(T);
+                T *tile2 = reinterpret_cast<T*>(tile[row]);
+                tile2[col] = src[imat*n + (y+j)*ne01 + x];
+            }
+        }
+
+        __syncthreads();
+
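+        // write the tile back with rows and columns swapped, yielding the transposed matrix in dst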
+#pragma unroll
+        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
+            if (ty + j < ne01 && tx < ne00) {
+                const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T);
+                const T *tile2 = reinterpret_cast<const T*>(tile[threadIdx.x]);
+                dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];
+            }
+        }
+    }
+}
+
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
    float * cdstf = (float *)(cdsti);

@@ -136,15 +189,38 @@ cudaStream_t stream) {
        (cx, cdst, ne);
}

-template<typename src_t, typename dst_t>
+template<typename src_t, typename dst_t, bool transposed = false>
static void ggml_cpy_flt_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+    if (transposed) {
+        GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is assumed to be 1
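+        // choose a 2D view of the data: if nb00 < nb02, each of the ne02 slices is transposed as
+        // its own matrix; otherwise ne02 is folded into ne01 and the tensor is handled as one matrix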
+        int ne00n, ne01n, ne02n;
+        if (nb00 < nb02) {
+            ne00n = ne00;
+            ne01n = ne01;
+            ne02n = ne02;
+        } else if (nb00 > nb02) {
+            ne00n = ne00;
+            ne01n = ne01*ne02;
+            ne02n = 1;
+        } else {
+            GGML_ASSERT(false);
+        }
+
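+        // 32x8 thread blocks: each block transposes a 32x32 tile, and gridDim.z steps through
+        // the batch CUDA_CPY_BLOCK_NM matrices at a time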
+        dim3 dimGrid((ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
+                     (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
+                     (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
+        dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
+        cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
+            (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+    } else {
+        const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+        cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+            (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+    }
}

static void ggml_cpy_f32_q8_0_cuda(
@@ -310,6 +386,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
    char * src1_ddc = (char *) src1->data;

    const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
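+    // the transpose path requires src0 to be a transposed 2D/3D view: a dim-1 stride of one
+    // element and no 4th dimension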
+    const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1;

    if (src0->type == src1->type && contiguous_srcs) {
        GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
@@ -322,7 +399,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
            CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (can_be_transposed) {
+            ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
        if (contiguous_srcs) {
            ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
@@ -361,7 +442,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (can_be_transposed) {
+            ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
        if (contiguous_srcs) {
            ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
@@ -375,7 +460,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
            ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (can_be_transposed) {
+            ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
        if (contiguous_srcs) {
            ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);