|
@@ -16,9 +16,10 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
|
|
|
|
|
|
|
|
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
|
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
|
|
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
|
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
|
|
|
|
+template<bool forward>
|
|
|
static __device__ void rope_yarn(
|
|
static __device__ void rope_yarn(
|
|
|
- float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
|
|
|
|
|
- float * cos_theta, float * sin_theta) {
|
|
|
|
|
|
|
+ const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor,
|
|
|
|
|
+ float mscale, float & cos_theta, float & sin_theta) {
|
|
|
// Get n-d rotational scaling corrected for extrapolation
|
|
// Get n-d rotational scaling corrected for extrapolation
|
|
|
float theta_interp = freq_scale * theta_extrap;
|
|
float theta_interp = freq_scale * theta_extrap;
|
|
|
float theta = theta_interp;
|
|
float theta = theta_interp;
|
|
@@ -29,24 +30,28 @@ static __device__ void rope_yarn(
|
|
|
// Get n-d magnitude scaling corrected for interpolation
|
|
// Get n-d magnitude scaling corrected for interpolation
|
|
|
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
|
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
|
|
}
|
|
}
|
|
|
- *cos_theta = cosf(theta) * mscale;
|
|
|
|
|
- *sin_theta = sinf(theta) * mscale;
|
|
|
|
|
|
|
+ cos_theta = cosf(theta) * mscale;
|
|
|
|
|
+ sin_theta = sinf(theta) * mscale;
|
|
|
|
|
+ if (!forward) {
|
|
|
|
|
+ sin_theta *= -1.0f;
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-template<typename T, bool has_ff>
|
|
|
|
|
|
|
+template<bool forward, bool has_ff, typename T>
|
|
|
static __global__ void rope_norm(
|
|
static __global__ void rope_norm(
|
|
|
- const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
|
|
|
|
|
|
|
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
|
|
|
|
|
+ const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
|
|
|
|
+ const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
|
|
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
|
|
|
|
|
|
|
if (i0 >= ne0) {
|
|
if (i0 >= ne0) {
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- const int row = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
+ const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
if (i0 >= n_dims) {
|
|
if (i0 >= n_dims) {
|
|
|
- const int i = row*ne0 + i0;
|
|
|
|
|
|
|
+ const int i = row_dst*ne0 + i0;
|
|
|
|
|
|
|
|
dst[i + 0] = x[i + 0];
|
|
dst[i + 0] = x[i + 0];
|
|
|
dst[i + 1] = x[i + 1];
|
|
dst[i + 1] = x[i + 1];
|
|
@@ -54,39 +59,43 @@ static __global__ void rope_norm(
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- const int i = row*ne0 + i0;
|
|
|
|
|
- const int i2 = row/p_delta_rows;
|
|
|
|
|
|
|
+ const int row_x = row_dst % ne1;
|
|
|
|
|
+ const int channel_x = row_dst / ne1;
|
|
|
|
|
+
|
|
|
|
|
+ const int idst = row_dst*ne0 + i0;
|
|
|
|
|
+ const int ix = channel_x*s2 + row_x*s1 + i0;
|
|
|
|
|
|
|
|
- const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
|
|
|
|
|
|
|
+ const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
|
|
|
|
|
|
|
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
|
|
|
|
|
|
float cos_theta;
|
|
float cos_theta;
|
|
|
float sin_theta;
|
|
float sin_theta;
|
|
|
|
|
|
|
|
- rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
|
|
|
|
|
+ rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
|
|
|
|
|
|
|
|
- const float x0 = x[i + 0];
|
|
|
|
|
- const float x1 = x[i + 1];
|
|
|
|
|
|
|
+ const float x0 = x[ix + 0];
|
|
|
|
|
+ const float x1 = x[ix + 1];
|
|
|
|
|
|
|
|
- dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
|
|
|
|
- dst[i + 1] = x0*sin_theta + x1*cos_theta;
|
|
|
|
|
|
|
+ dst[idst + 0] = x0*cos_theta - x1*sin_theta;
|
|
|
|
|
+ dst[idst + 1] = x0*sin_theta + x1*cos_theta;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-template<typename T, bool has_ff>
|
|
|
|
|
|
|
+template<bool forward, bool has_ff, typename T>
|
|
|
static __global__ void rope_neox(
|
|
static __global__ void rope_neox(
|
|
|
- const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
|
|
|
|
|
|
|
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
|
|
|
|
|
+ const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
|
|
|
|
+ const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
|
|
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
|
|
|
|
|
|
|
if (i0 >= ne0) {
|
|
if (i0 >= ne0) {
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- const int row = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
+ const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
if (i0 >= n_dims) {
|
|
if (i0 >= n_dims) {
|
|
|
- const int i = row*ne0 + i0;
|
|
|
|
|
|
|
+ const int i = row_dst*ne0 + i0;
|
|
|
|
|
|
|
|
dst[i + 0] = x[i + 0];
|
|
dst[i + 0] = x[i + 0];
|
|
|
dst[i + 1] = x[i + 1];
|
|
dst[i + 1] = x[i + 1];
|
|
@@ -94,39 +103,43 @@ static __global__ void rope_neox(
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- const int i = row*ne0 + i0/2;
|
|
|
|
|
- const int i2 = row/p_delta_rows;
|
|
|
|
|
|
|
+ const int row_x = row_dst % ne1;
|
|
|
|
|
+ const int channel_x = row_dst / ne1;
|
|
|
|
|
+
|
|
|
|
|
+ const int idst = row_dst*ne0 + i0/2;
|
|
|
|
|
+ const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
|
|
|
|
|
|
|
- const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
|
|
|
|
|
|
|
+ const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
|
|
|
|
|
|
|
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
|
|
|
|
|
|
float cos_theta;
|
|
float cos_theta;
|
|
|
float sin_theta;
|
|
float sin_theta;
|
|
|
|
|
|
|
|
- rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
|
|
|
|
|
+ rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
|
|
|
|
|
|
|
|
- const float x0 = x[i + 0];
|
|
|
|
|
- const float x1 = x[i + n_dims/2];
|
|
|
|
|
|
|
+ const float x0 = x[ix + 0];
|
|
|
|
|
+ const float x1 = x[ix + n_dims/2];
|
|
|
|
|
|
|
|
- dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
|
|
|
|
- dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
|
|
|
|
|
+ dst[idst + 0] = x0*cos_theta - x1*sin_theta;
|
|
|
|
|
+ dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-template<typename T, bool has_ff>
|
|
|
|
|
|
|
+template<bool forward, bool has_ff, typename T>
|
|
|
static __global__ void rope_multi(
|
|
static __global__ void rope_multi(
|
|
|
- const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
|
|
|
|
|
|
|
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
|
|
|
|
|
+ const int n_dims, const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
|
|
|
|
+ const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
|
|
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
|
|
|
|
|
|
|
if (i0 >= ne0) {
|
|
if (i0 >= ne0) {
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- const int row = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
+ const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
if (i0 >= n_dims) {
|
|
if (i0 >= n_dims) {
|
|
|
- const int i = row*ne0 + i0;
|
|
|
|
|
|
|
+ const int i = row_dst*ne0 + i0;
|
|
|
|
|
|
|
|
dst[i + 0] = x[i + 0];
|
|
dst[i + 0] = x[i + 0];
|
|
|
dst[i + 1] = x[i + 1];
|
|
dst[i + 1] = x[i + 1];
|
|
@@ -134,25 +147,28 @@ static __global__ void rope_multi(
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- const int i = row*ne0 + i0/2;
|
|
|
|
|
- const int i2 = row/p_delta_rows;
|
|
|
|
|
|
|
+ const int row_x = row_dst % ne1;
|
|
|
|
|
+ const int channel_x = row_dst / ne1;
|
|
|
|
|
|
|
|
- int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
|
|
|
|
- int sec_w = sections.v[1] + sections.v[0];
|
|
|
|
|
- int sector = (i0 / 2) % sect_dims;
|
|
|
|
|
|
|
+ const int idst = row_dst*ne0 + i0/2;
|
|
|
|
|
+ const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
|
|
|
|
+
|
|
|
|
|
+ const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
|
|
|
|
+ const int sec_w = sections.v[1] + sections.v[0];
|
|
|
|
|
+ const int sector = (i0 / 2) % sect_dims;
|
|
|
|
|
|
|
|
float theta_base = 0.0;
|
|
float theta_base = 0.0;
|
|
|
if (sector < sections.v[0]) {
|
|
if (sector < sections.v[0]) {
|
|
|
- theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
|
|
|
|
|
|
|
+ theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
|
|
}
|
|
}
|
|
|
else if (sector >= sections.v[0] && sector < sec_w) {
|
|
else if (sector >= sections.v[0] && sector < sec_w) {
|
|
|
- theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
|
|
|
|
|
|
+ theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
|
|
}
|
|
}
|
|
|
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
|
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
|
|
- theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
|
|
|
|
|
|
+ theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
|
|
}
|
|
}
|
|
|
else if (sector >= sec_w + sections.v[2]) {
|
|
else if (sector >= sec_w + sections.v[2]) {
|
|
|
- theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f);
|
|
|
|
|
|
|
+ theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
@@ -160,42 +176,46 @@ static __global__ void rope_multi(
|
|
|
float cos_theta;
|
|
float cos_theta;
|
|
|
float sin_theta;
|
|
float sin_theta;
|
|
|
|
|
|
|
|
- rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
|
|
|
|
|
+ rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
|
|
|
|
|
|
|
|
- const float x0 = x[i + 0];
|
|
|
|
|
- const float x1 = x[i + n_dims/2];
|
|
|
|
|
|
|
+ const float x0 = x[ix + 0];
|
|
|
|
|
+ const float x1 = x[ix + n_dims/2];
|
|
|
|
|
|
|
|
- dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
|
|
|
|
- dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
|
|
|
|
|
+ dst[idst + 0] = x0*cos_theta - x1*sin_theta;
|
|
|
|
|
+ dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-template<typename T, bool has_ff>
|
|
|
|
|
|
|
+template<bool forward, bool has_ff, typename T>
|
|
|
static __global__ void rope_vision(
|
|
static __global__ void rope_vision(
|
|
|
- const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
|
|
|
|
|
|
|
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
|
|
|
|
|
+ const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
|
|
|
|
+ const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
|
|
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
|
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
|
|
|
|
|
|
|
if (i0 >= ne0) {
|
|
if (i0 >= ne0) {
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- const int row = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
+ const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
+
|
|
|
|
|
+ const int row_x = row_dst % ne1;
|
|
|
|
|
+ const int channel_x = row_dst / ne1;
|
|
|
|
|
|
|
|
- const int i = row*ne0 + i0/2;
|
|
|
|
|
- const int i2 = row/p_delta_rows; // i2-th tokens
|
|
|
|
|
|
|
+ const int idst = row_dst*ne0 + i0/2;
|
|
|
|
|
+ const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
|
|
|
|
|
|
|
- int sect_dims = sections.v[0] + sections.v[1];
|
|
|
|
|
- int sec_w = sections.v[1] + sections.v[0];
|
|
|
|
|
- int sector = (i0 / 2) % sect_dims;
|
|
|
|
|
|
|
+ const int sect_dims = sections.v[0] + sections.v[1];
|
|
|
|
|
+ const int sec_w = sections.v[1] + sections.v[0];
|
|
|
|
|
+ const int sector = (i0 / 2) % sect_dims;
|
|
|
|
|
|
|
|
float theta_base = 0.0;
|
|
float theta_base = 0.0;
|
|
|
if (sector < sections.v[0]) {
|
|
if (sector < sections.v[0]) {
|
|
|
const int p = sector;
|
|
const int p = sector;
|
|
|
- theta_base = pos[i2]*powf(theta_scale, p);
|
|
|
|
|
|
|
+ theta_base = pos[channel_x]*powf(theta_scale, p);
|
|
|
}
|
|
}
|
|
|
else if (sector >= sections.v[0] && sector < sec_w) {
|
|
else if (sector >= sections.v[0] && sector < sec_w) {
|
|
|
const int p = sector - sections.v[0];
|
|
const int p = sector - sections.v[0];
|
|
|
- theta_base = pos[i2 + ne2]*powf(theta_scale, p);
|
|
|
|
|
|
|
+ theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
@@ -203,19 +223,20 @@ static __global__ void rope_vision(
|
|
|
float cos_theta;
|
|
float cos_theta;
|
|
|
float sin_theta;
|
|
float sin_theta;
|
|
|
|
|
|
|
|
- rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
|
|
|
|
|
+ rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
|
|
|
|
|
|
|
|
- const float x0 = x[i + 0];
|
|
|
|
|
- const float x1 = x[i + n_dims];
|
|
|
|
|
|
|
+ const float x0 = x[ix + 0];
|
|
|
|
|
+ const float x1 = x[ix + n_dims];
|
|
|
|
|
|
|
|
- dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
|
|
|
|
- dst[i + n_dims] = x0*sin_theta + x1*cos_theta;
|
|
|
|
|
|
|
+ dst[idst + 0] = x0*cos_theta - x1*sin_theta;
|
|
|
|
|
+ dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-template<typename T>
|
|
|
|
|
|
|
+template<bool forward, typename T>
|
|
|
static void rope_norm_cuda(
|
|
static void rope_norm_cuda(
|
|
|
- const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
|
|
|
|
|
|
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
|
|
|
|
|
+ const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
|
|
|
|
+ const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
|
|
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
|
|
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
@@ -224,22 +245,21 @@ static void rope_norm_cuda(
|
|
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
|
|
|
|
|
|
if (freq_factors == nullptr) {
|
|
if (freq_factors == nullptr) {
|
|
|
- rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
- x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
|
|
|
|
- theta_scale, freq_factors
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
+ x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
|
|
|
+ attn_factor, corr_dims, theta_scale, freq_factors);
|
|
|
} else {
|
|
} else {
|
|
|
- rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
- x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
|
|
|
|
- theta_scale, freq_factors
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
+ x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
|
|
|
+ attn_factor, corr_dims, theta_scale, freq_factors);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-template<typename T>
|
|
|
|
|
|
|
+template<bool forward, typename T>
|
|
|
static void rope_neox_cuda(
|
|
static void rope_neox_cuda(
|
|
|
- const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
|
|
|
|
|
|
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
|
|
|
|
|
+ const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
|
|
|
|
+ const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
|
|
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
|
|
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
@@ -248,22 +268,21 @@ static void rope_neox_cuda(
|
|
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
|
|
|
|
|
|
if (freq_factors == nullptr) {
|
|
if (freq_factors == nullptr) {
|
|
|
- rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
- x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
|
|
|
|
- theta_scale, freq_factors
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_neox<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
+ x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
|
|
|
+ attn_factor, corr_dims, theta_scale, freq_factors);
|
|
|
} else {
|
|
} else {
|
|
|
- rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
- x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
|
|
|
|
- theta_scale, freq_factors
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_neox<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
+ x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
|
|
|
+ attn_factor, corr_dims, theta_scale, freq_factors);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-template<typename T>
|
|
|
|
|
|
|
+template<bool forward, typename T>
|
|
|
static void rope_multi_cuda(
|
|
static void rope_multi_cuda(
|
|
|
- const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
|
|
|
|
|
|
|
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
|
|
|
|
|
+ const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
|
|
|
|
+ const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
|
|
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
|
|
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
@@ -272,22 +291,21 @@ static void rope_multi_cuda(
|
|
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
|
|
|
|
|
|
if (freq_factors == nullptr) {
|
|
if (freq_factors == nullptr) {
|
|
|
- rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
- x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
|
|
|
|
- theta_scale, freq_factors, sections
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
+ x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
|
|
|
+ attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
|
|
} else {
|
|
} else {
|
|
|
- rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
- x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
|
|
|
|
- theta_scale, freq_factors, sections
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
+ x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
|
|
|
+ attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-template<typename T>
|
|
|
|
|
|
|
+template<bool forward, typename T>
|
|
|
static void rope_vision_cuda(
|
|
static void rope_vision_cuda(
|
|
|
- const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
|
|
|
|
|
|
|
+ const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
|
|
|
|
|
+ const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
|
|
|
|
+ const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
|
|
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
|
|
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
|
@@ -298,80 +316,18 @@ static void rope_vision_cuda(
|
|
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
|
|
|
|
|
|
if (freq_factors == nullptr) {
|
|
if (freq_factors == nullptr) {
|
|
|
- rope_vision<T, false><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
- x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
|
|
|
|
- theta_scale, freq_factors, sections
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
+ x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
|
|
|
+ attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
|
|
} else {
|
|
} else {
|
|
|
- rope_vision<T, true><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
- x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
|
|
|
|
- theta_scale, freq_factors, sections
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
|
|
|
|
+ x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
|
|
|
+ attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-static void rope_norm_cuda_f16(
|
|
|
|
|
- const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
|
|
|
|
-
|
|
|
|
|
- rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-static void rope_norm_cuda_f32(
|
|
|
|
|
- const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
|
|
|
|
-
|
|
|
|
|
- rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-static void rope_neox_cuda_f16(
|
|
|
|
|
- const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
|
|
|
|
-
|
|
|
|
|
- rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-static void rope_neox_cuda_f32(
|
|
|
|
|
- const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
|
|
|
|
|
-) {
|
|
|
|
|
-
|
|
|
|
|
- rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-static void rope_multi_cuda_f16(
|
|
|
|
|
- const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
|
|
|
|
-) {
|
|
|
|
|
-
|
|
|
|
|
- rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-static void rope_multi_cuda_f32(
|
|
|
|
|
- const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
|
|
|
|
-) {
|
|
|
|
|
-
|
|
|
|
|
- rope_multi_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-static void rope_vision_cuda_f16(
|
|
|
|
|
- const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
|
|
|
|
-) {
|
|
|
|
|
-
|
|
|
|
|
- rope_vision_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-static void rope_vision_cuda_f32(
|
|
|
|
|
- const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
|
|
|
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
|
|
|
|
|
-) {
|
|
|
|
|
-
|
|
|
|
|
- rope_vision_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
|
|
+template <bool forward>
|
|
|
|
|
+void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
const ggml_tensor * src1 = dst->src[1];
|
|
const ggml_tensor * src1 = dst->src[1];
|
|
|
const ggml_tensor * src2 = dst->src[2];
|
|
const ggml_tensor * src2 = dst->src[2];
|
|
@@ -382,7 +338,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
float * dst_d = (float *)dst->data;
|
|
float * dst_d = (float *)dst->data;
|
|
|
cudaStream_t stream = ctx.stream();
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
- GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
@@ -392,6 +347,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
const int64_t ne02 = src0->ne[2]; // num heads
|
|
const int64_t ne02 = src0->ne[2]; // num heads
|
|
|
const int64_t nr = ggml_nrows(src0);
|
|
const int64_t nr = ggml_nrows(src0);
|
|
|
|
|
|
|
|
|
|
+ const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
|
|
|
|
|
+ const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
|
|
|
|
|
+
|
|
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
|
const int mode = ((int32_t *) dst->op_params)[2];
|
|
const int mode = ((int32_t *) dst->op_params)[2];
|
|
@@ -440,59 +398,59 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
// compute
|
|
// compute
|
|
|
if (is_neox) {
|
|
if (is_neox) {
|
|
|
if (src0->type == GGML_TYPE_F32) {
|
|
if (src0->type == GGML_TYPE_F32) {
|
|
|
- rope_neox_cuda_f32(
|
|
|
|
|
- (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
|
|
|
|
- attn_factor, corr_dims, freq_factors, stream
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_neox_cuda<forward>(
|
|
|
|
|
+ (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
|
|
|
|
|
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
|
|
} else if (src0->type == GGML_TYPE_F16) {
|
|
} else if (src0->type == GGML_TYPE_F16) {
|
|
|
- rope_neox_cuda_f16(
|
|
|
|
|
- (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
|
|
|
|
- attn_factor, corr_dims, freq_factors, stream
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_neox_cuda<forward>(
|
|
|
|
|
+ (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
|
|
|
|
|
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
|
|
} else {
|
|
} else {
|
|
|
GGML_ABORT("fatal error");
|
|
GGML_ABORT("fatal error");
|
|
|
}
|
|
}
|
|
|
} else if (is_mrope && !is_vision) {
|
|
} else if (is_mrope && !is_vision) {
|
|
|
if (src0->type == GGML_TYPE_F32) {
|
|
if (src0->type == GGML_TYPE_F32) {
|
|
|
- rope_multi_cuda_f32(
|
|
|
|
|
- (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
|
|
|
|
- attn_factor, corr_dims, freq_factors, sections, stream
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_multi_cuda<forward>(
|
|
|
|
|
+ (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
|
|
|
|
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
|
|
} else if (src0->type == GGML_TYPE_F16) {
|
|
} else if (src0->type == GGML_TYPE_F16) {
|
|
|
- rope_multi_cuda_f16(
|
|
|
|
|
- (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
|
|
|
|
- attn_factor, corr_dims, freq_factors, sections, stream
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_multi_cuda<forward>(
|
|
|
|
|
+ (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
|
|
|
|
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
|
|
} else {
|
|
} else {
|
|
|
GGML_ABORT("fatal error");
|
|
GGML_ABORT("fatal error");
|
|
|
}
|
|
}
|
|
|
} else if (is_vision) {
|
|
} else if (is_vision) {
|
|
|
if (src0->type == GGML_TYPE_F32) {
|
|
if (src0->type == GGML_TYPE_F32) {
|
|
|
- rope_vision_cuda_f32(
|
|
|
|
|
- (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
|
|
|
|
- attn_factor, corr_dims, freq_factors, sections, stream
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_vision_cuda<forward>(
|
|
|
|
|
+ (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
|
|
|
|
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
|
|
} else if (src0->type == GGML_TYPE_F16) {
|
|
} else if (src0->type == GGML_TYPE_F16) {
|
|
|
- rope_vision_cuda_f16(
|
|
|
|
|
- (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
|
|
|
|
- attn_factor, corr_dims, freq_factors, sections, stream
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_vision_cuda<forward>(
|
|
|
|
|
+ (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
|
|
|
|
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
|
|
} else {
|
|
} else {
|
|
|
GGML_ABORT("fatal error");
|
|
GGML_ABORT("fatal error");
|
|
|
}
|
|
}
|
|
|
} else {
|
|
} else {
|
|
|
if (src0->type == GGML_TYPE_F32) {
|
|
if (src0->type == GGML_TYPE_F32) {
|
|
|
- rope_norm_cuda_f32(
|
|
|
|
|
- (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
|
|
|
|
- attn_factor, corr_dims, freq_factors, stream
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_norm_cuda<forward>(
|
|
|
|
|
+ (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
|
|
|
|
|
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
|
|
} else if (src0->type == GGML_TYPE_F16) {
|
|
} else if (src0->type == GGML_TYPE_F16) {
|
|
|
- rope_norm_cuda_f16(
|
|
|
|
|
- (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
|
|
|
|
- attn_factor, corr_dims, freq_factors, stream
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ rope_norm_cuda<forward>(
|
|
|
|
|
+ (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
|
|
|
|
|
+ freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
|
|
} else {
|
|
} else {
|
|
|
GGML_ABORT("fatal error");
|
|
GGML_ABORT("fatal error");
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
+ ggml_cuda_op_rope_impl<true>(ctx, dst);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
+ ggml_cuda_op_rope_impl<false>(ctx, dst);
|
|
|
|
|
+}
|