|
@@ -392,7 +392,8 @@ kernel void kernel_rope_multi_f32(
|
|
|
float attn_factor,
|
|
float attn_factor,
|
|
|
float beta_fast,
|
|
float beta_fast,
|
|
|
float beta_slow,
|
|
float beta_slow,
|
|
|
- int4 sections
|
|
|
|
|
|
|
+ int4 sections,
|
|
|
|
|
+ int is_imrope
|
|
|
) {
|
|
) {
|
|
|
src0 = (global void*)((global char*)src0 + offset0);
|
|
src0 = (global void*)((global char*)src0 + offset0);
|
|
|
src1 = (global int*)((global char*)src1 + offset1);
|
|
src1 = (global int*)((global char*)src1 + offset1);
|
|
@@ -419,17 +420,29 @@ kernel void kernel_rope_multi_f32(
|
|
|
const int sector = (i0 / 2) % sect_dims;
|
|
const int sector = (i0 / 2) % sect_dims;
|
|
|
float theta_base = 0.0f;
|
|
float theta_base = 0.0f;
|
|
|
|
|
|
|
|
- if (sector < sections.s0) {
|
|
|
|
|
- theta_base = pos[i2];
|
|
|
|
|
- }
|
|
|
|
|
- else if (sector >= sections.s0 && sector < sec_w) {
|
|
|
|
|
- theta_base = pos[i2 + ne2 * 1];
|
|
|
|
|
- }
|
|
|
|
|
- else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
|
|
|
|
- theta_base = pos[i2 + ne2 * 2];
|
|
|
|
|
- }
|
|
|
|
|
- else if (sector >= sec_w + sections.s2) {
|
|
|
|
|
- theta_base = pos[i2 + ne2 * 3];
|
|
|
|
|
|
|
+ if (is_imrope) {
|
|
|
|
|
+ if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
|
|
|
|
|
+ theta_base = (float) pos[i2 + ne02 * 1];
|
|
|
|
|
+ } else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
|
|
|
|
|
+ theta_base = (float) pos[i2 + ne02 * 2];
|
|
|
|
|
+ } else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
|
|
|
|
|
+ theta_base = (float) pos[i2 + ne02 * 0];
|
|
|
|
|
+ } else { // e
|
|
|
|
|
+ theta_base = (float) pos[i2 + ne02 * 3];
|
|
|
|
|
+ }
|
|
|
|
|
+ } else {
|
|
|
|
|
+ if (sector < sections.s0) {
|
|
|
|
|
+ theta_base = pos[i2];
|
|
|
|
|
+ }
|
|
|
|
|
+ else if (sector >= sections.s0 && sector < sec_w) {
|
|
|
|
|
+ theta_base = pos[i2 + ne2 * 1];
|
|
|
|
|
+ }
|
|
|
|
|
+ else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
|
|
|
|
+ theta_base = pos[i2 + ne2 * 2];
|
|
|
|
|
+ }
|
|
|
|
|
+ else if (sector >= sec_w + sections.s2) {
|
|
|
|
|
+ theta_base = pos[i2 + ne2 * 3];
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
|
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
|
@@ -490,7 +503,8 @@ kernel void kernel_rope_multi_f16(
|
|
|
float attn_factor,
|
|
float attn_factor,
|
|
|
float beta_fast,
|
|
float beta_fast,
|
|
|
float beta_slow,
|
|
float beta_slow,
|
|
|
- int4 sections
|
|
|
|
|
|
|
+ int4 sections,
|
|
|
|
|
+ int is_imrope
|
|
|
) {
|
|
) {
|
|
|
src0 = (global void*)((global char*)src0 + offset0);
|
|
src0 = (global void*)((global char*)src0 + offset0);
|
|
|
src1 = (global int*)((global char*)src1 + offset1);
|
|
src1 = (global int*)((global char*)src1 + offset1);
|
|
@@ -517,17 +531,29 @@ kernel void kernel_rope_multi_f16(
|
|
|
const int sector = (i0 / 2) % sect_dims;
|
|
const int sector = (i0 / 2) % sect_dims;
|
|
|
float theta_base = 0.0f;
|
|
float theta_base = 0.0f;
|
|
|
|
|
|
|
|
- if (sector < sections.s0) {
|
|
|
|
|
- theta_base = pos[i2];
|
|
|
|
|
- }
|
|
|
|
|
- else if (sector >= sections.s0 && sector < sec_w) {
|
|
|
|
|
- theta_base = pos[i2 + ne2 * 1];
|
|
|
|
|
- }
|
|
|
|
|
- else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
|
|
|
|
- theta_base = pos[i2 + ne2 * 2];
|
|
|
|
|
- }
|
|
|
|
|
- else if (sector >= sec_w + sections.s2) {
|
|
|
|
|
- theta_base = pos[i2 + ne2 * 3];
|
|
|
|
|
|
|
+ if (is_imrope) {
|
|
|
|
|
+ if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
|
|
|
|
|
+ theta_base = (float) pos[i2 + ne02 * 1];
|
|
|
|
|
+ } else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
|
|
|
|
|
+ theta_base = (float) pos[i2 + ne02 * 2];
|
|
|
|
|
+ } else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
|
|
|
|
|
+ theta_base = (float) pos[i2 + ne02 * 0];
|
|
|
|
|
+ } else { // e
|
|
|
|
|
+ theta_base = (float) pos[i2 + ne02 * 3];
|
|
|
|
|
+ }
|
|
|
|
|
+ } else {
|
|
|
|
|
+ if (sector < sections.s0) {
|
|
|
|
|
+ theta_base = pos[i2];
|
|
|
|
|
+ }
|
|
|
|
|
+ else if (sector >= sections.s0 && sector < sec_w) {
|
|
|
|
|
+ theta_base = pos[i2 + ne2 * 1];
|
|
|
|
|
+ }
|
|
|
|
|
+ else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
|
|
|
|
+ theta_base = pos[i2 + ne2 * 2];
|
|
|
|
|
+ }
|
|
|
|
|
+ else if (sector >= sec_w + sections.s2) {
|
|
|
|
|
+ theta_base = pos[i2 + ne2 * 3];
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
|
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|