@@ -1,3 +1,6 @@
+#include "convert.cuh"
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
 #include "rope.cuh"

 struct rope_corr_dims {
@@ -37,11 +40,23 @@ static __device__ void rope_yarn(
     }
 }

-template<bool forward, bool has_ff, typename T>
-static __global__ void rope_norm(
-    const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-    const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
+template <bool forward, bool has_ff, typename T, typename D>
+static __global__ void rope_norm(const T * x,
+                                 D * dst,
+                                 const int ne0,
+                                 const int ne1,
+                                 const int s1,
+                                 const int s2,
+                                 const int n_dims,
+                                 const int32_t * pos,
+                                 const float freq_scale,
+                                 const float ext_factor,
+                                 const float attn_factor,
+                                 const rope_corr_dims corr_dims,
+                                 const float theta_scale,
+                                 const float * freq_factors,
+                                 const int64_t * row_indices,
+                                 const int set_rows_stride) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (i0 >= ne0) {
@@ -53,13 +68,27 @@ static __global__ void rope_norm(
     const int row_x     = row_dst % ne1;
     const int channel_x = row_dst / ne1;

-    const int idst = row_dst*ne0 + i0;
+    int idst = row_dst * ne0 + i0;
     const int ix = channel_x*s2 + row_x*s1 + i0;

-    if (i0 >= n_dims) {
-        dst[idst + 0] = x[ix + 0];
-        dst[idst + 1] = x[ix + 1];
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
+    if (set_rows_stride != 0) {
+        idst  = row_x * ne0 + i0;
+        idst += row_indices[channel_x] * set_rows_stride;
+    }
+
+    const auto & store_coalesced = [&](float x0, float x1) {
+        if constexpr (std::is_same_v<float, D>) {
+            float2 v = make_float2(x0, x1);
+            ggml_cuda_memcpy_1<8>(dst + idst, &v);
+        } else if constexpr (std::is_same_v<half, D>) {
+            half2 v = make_half2(x0, x1);
+            ggml_cuda_memcpy_1<4>(dst + idst, &v);
+        }
+    };
+
+    if (i0 >= n_dims) {
+        store_coalesced(x[ix + 0], x[ix + 1]);
         return;
     }
@@ -75,15 +104,26 @@ static __global__ void rope_norm(
     const float x0 = x[ix + 0];
     const float x1 = x[ix + 1];

-    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
-    dst[idst + 1] = x0*sin_theta + x1*cos_theta;
+    store_coalesced(x0 * cos_theta - x1 * sin_theta, x0 * sin_theta + x1 * cos_theta);
 }

-template<bool forward, bool has_ff, typename T>
-static __global__ void rope_neox(
-    const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-    const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
+template <bool forward, bool has_ff, typename T, typename D>
+static __global__ void rope_neox(const T * x,
+                                 D * dst,
+                                 const int ne0,
+                                 const int ne1,
+                                 const int s1,
+                                 const int s2,
+                                 const int n_dims,
+                                 const int32_t * pos,
+                                 const float freq_scale,
+                                 const float ext_factor,
+                                 const float attn_factor,
+                                 const rope_corr_dims corr_dims,
+                                 const float theta_scale,
+                                 const float * freq_factors,
+                                 const int64_t * row_indices,
+                                 const int set_rows_stride) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (i0 >= ne0) {
@@ -95,12 +135,19 @@ static __global__ void rope_neox(
     const int row_x     = row_dst % ne1;
     const int channel_x = row_dst / ne1;

-    const int idst = row_dst*ne0 + i0/2;
+    int idst = row_dst * ne0 + i0 / 2;
     const int ix = channel_x*s2 + row_x*s1 + i0/2;

+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
+    if (set_rows_stride != 0) {
+        idst  = row_x * ne0 + i0 / 2;
+        idst += row_indices[channel_x] * set_rows_stride;
+    }
+
     if (i0 >= n_dims) {
-        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
-        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
+        dst[idst + i0 / 2 + 0] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 0]);
+        dst[idst + i0 / 2 + 1] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 1]);
         return;
     }
@@ -117,8 +164,8 @@ static __global__ void rope_neox(
     const float x0 = x[ix + 0];
     const float x1 = x[ix + n_dims/2];

-    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
-    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0] = ggml_cuda_cast<D>(x0 * cos_theta - x1 * sin_theta);
+    dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta);
 }

 template<bool forward, bool has_ff, typename T>
@@ -238,11 +285,25 @@ static __global__ void rope_vision(
     dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
 }

-template<bool forward, typename T>
-static void rope_norm_cuda(
-    const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
-    const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+template <bool forward, typename T, typename D>
+static void rope_norm_cuda(const T * x,
+                           D * dst,
+                           const int ne0,
+                           const int ne1,
+                           const int s1,
+                           const int s2,
+                           const int n_dims,
+                           const int nr,
+                           const int32_t * pos,
+                           const float freq_scale,
+                           const float freq_base,
+                           const float ext_factor,
+                           const float attn_factor,
+                           const rope_corr_dims corr_dims,
+                           const float * freq_factors,
+                           const int64_t * row_indices,
+                           const int set_rows_stride,
+                           cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -252,20 +313,34 @@ static void rope_norm_cuda(

     if (freq_factors == nullptr) {
         rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors);
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     } else {
         rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors);
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     }
 }

-template<bool forward, typename T>
-static void rope_neox_cuda(
-    const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
-    const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+template <bool forward, typename T, typename D>
+static void rope_neox_cuda(const T * x,
+                           D * dst,
+                           const int ne0,
+                           const int ne1,
+                           const int s1,
+                           const int s2,
+                           const int n_dims,
+                           const int nr,
+                           const int32_t * pos,
+                           const float freq_scale,
+                           const float freq_base,
+                           const float ext_factor,
+                           const float attn_factor,
+                           const rope_corr_dims corr_dims,
+                           const float * freq_factors,
+                           const int64_t * row_indices,
+                           const int set_rows_stride,
+                           cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -274,13 +349,13 @@ static void rope_neox_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);

     if (freq_factors == nullptr) {
-        rope_neox<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors);
+        rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     } else {
-        rope_neox<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors);
+        rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     }
 }

@@ -333,7 +408,9 @@ static void rope_vision_cuda(
 }

 template <bool forward>
-void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
+                            ggml_tensor * dst,
+                            const ggml_tensor * set_rows = nullptr) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * src2 = dst->src[2];
@@ -341,12 +418,25 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;

-    float * dst_d = (float *)dst->data;
+    void * dst_d = dst->data;
+    const int64_t * row_indices = nullptr;
+    ggml_type dst_type = dst->type;
+    int set_rows_stride = 0;
+
+    if (set_rows != nullptr) {
+        GGML_ASSERT(forward);
+        dst_d = set_rows->data;
+        row_indices = (const int64_t *) set_rows->src[1]->data;
+        dst_type = set_rows->type;
+        set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
+    }
     cudaStream_t stream = ctx.stream();

     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-    GGML_ASSERT(src0->type == dst->type);
+    // When not fused, src0 and dst types must match
+    // When fused (ROPE+VIEW+SET_ROWS), src0 may be F32 and dst may be F16
+    GGML_ASSERT(src0->type == dst->type || (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));

     const int64_t ne00 = src0->ne[0]; // head dims
     const int64_t ne01 = src0->ne[1]; // num heads
@@ -404,14 +494,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)

     // compute
     if (is_neox) {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_neox_cuda<forward>(
-                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_neox_cuda<forward>(
-                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
+            rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                  freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
+            rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                                 nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                 freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
+            rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
+                                                pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                freq_factors, row_indices, set_rows_stride, stream);
         } else {
             GGML_ABORT("fatal error");
         }
@@ -440,14 +534,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
             GGML_ABORT("fatal error");
         }
     } else {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_norm_cuda<forward>(
-                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_norm_cuda<forward>(
-                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
+            rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                  freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
+            rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                                 nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                 freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
+            rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
+                                                pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                freq_factors, row_indices, set_rows_stride, stream);
         } else {
             GGML_ABORT("fatal error");
         }
@@ -461,3 +559,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_rope_impl<false>(ctx, dst);
 }
+
+void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {
+    ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows);
+}
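
Note (reviewer sketch, not part of the patch): the two ideas the new kernels combine are (1) redirecting the destination index through row_indices when the ROPE + VIEW + SET_ROWS fusion is active, and (2) packing each rotated (x0, x1) pair into a float2/half2 so it leaves the SM as a single 8- or 4-byte store (the patch routes that store through ggml_cuda_memcpy_1; the sketch below uses a plain float2 store). The standalone program demonstrates both under made-up names: rotate_scatter, rows, cols and dst_stride are hypothetical stand-ins for the kernels' ne1, ne0 and set_rows_stride, and a single fixed angle stands in for the per-position rope frequencies.

#include <cstdio>
#include <cuda_runtime.h>

// Rotate consecutive (x0, x1) pairs of each source row and write the row either
// in place (dst_stride == 0) or at the row given by row_indices (fused path).
// All names here are illustrative, not from the patch.
__global__ void rotate_scatter(const float * x, float * dst, int rows, int cols,
                               float theta, const int64_t * row_indices, int dst_stride) {
    const int i0  = 2 * (blockDim.x * blockIdx.x + threadIdx.x);
    const int row = blockIdx.y;
    if (i0 >= cols || row >= rows) {
        return;
    }

    const int ix   = row * cols + i0;   // source index
    int       idst = row * cols + i0;   // default: write in place
    if (dst_stride != 0) {              // fused path: scatter through row_indices
        idst = (int) row_indices[row] * dst_stride + i0;
    }

    const float x0 = x[ix + 0];
    const float x1 = x[ix + 1];

    // Pack the rotated pair so the write is one coalesced 8-byte store.
    const float2 v = make_float2(x0 * cosf(theta) - x1 * sinf(theta),
                                 x0 * sinf(theta) + x1 * cosf(theta));
    *(float2 *) (dst + idst) = v;
}

int main() {
    const int rows = 2, cols = 4, dst_rows = 4;
    const float   h_x[rows * cols] = {1, 0, 0, 1, 2, 0, 0, 2};
    const int64_t h_idx[rows]      = {3, 0};  // src row 0 -> dst row 3, src row 1 -> dst row 0

    float * d_x; float * d_dst; int64_t * d_idx;
    cudaMalloc(&d_x,   rows * cols * sizeof(float));
    cudaMalloc(&d_dst, dst_rows * cols * sizeof(float));
    cudaMalloc(&d_idx, rows * sizeof(int64_t));
    cudaMemcpy(d_x,   h_x,   rows * cols * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_idx, h_idx, rows * sizeof(int64_t),      cudaMemcpyHostToDevice);
    cudaMemset(d_dst, 0, dst_rows * cols * sizeof(float));

    rotate_scatter<<<dim3(1, rows), 32>>>(d_x, d_dst, rows, cols, 0.5f, d_idx, /*dst_stride=*/cols);

    float h_dst[dst_rows * cols];
    cudaMemcpy(h_dst, d_dst, sizeof(h_dst), cudaMemcpyDeviceToHost);
    for (int r = 0; r < dst_rows; ++r) {
        printf("dst row %d: %.3f %.3f %.3f %.3f\n",
               r, h_dst[r*cols + 0], h_dst[r*cols + 1], h_dst[r*cols + 2], h_dst[r*cols + 3]);
    }
    cudaFree(d_x); cudaFree(d_dst); cudaFree(d_idx);
    return 0;
}

Built with plain nvcc, source rows 0 and 1 land in destination rows 3 and 0: the rotated values go straight to their scattered destinations, with no intermediate rope tensor and no separate SET_ROWS copy, which is the memory traffic the fusion saves.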
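One more note on the host side: set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type) converts the destination's byte stride between rows (nb[1]) into an element stride, which is what the kernels multiply row_indices[channel_x] by. A minimal check of that arithmetic, with made-up numbers for an unpadded F16 destination of 4096 elements per row:

#include <cassert>
#include <cstddef>

int main() {
    const size_t type_size = 2;                // ggml_type_size(GGML_TYPE_F16)
    const size_t nb1       = 4096 * type_size; // nb[1]: byte stride between dst rows
    const int    set_rows_stride = (int) (nb1 / type_size);
    assert(set_rows_stride == 4096);           // element stride used by the kernels
    return 0;
}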