|
@@ -2217,7 +2217,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|
|
|
|
|
|
|
// ALiBi
|
|
// ALiBi
|
|
|
if (max_bias > 0.0f) {
|
|
if (max_bias > 0.0f) {
|
|
|
- const short h = iq2;
|
|
|
|
|
|
|
+ const uint32_t h = iq2;
|
|
|
|
|
|
|
|
const float base = h < n_head_log2 ? m0 : m1;
|
|
const float base = h < n_head_log2 ? m0 : m1;
|
|
|
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
|
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
|
@@ -2473,7 +2473,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
|
|
|
|
|
|
// ALiBi
|
|
// ALiBi
|
|
|
if (max_bias > 0.0f) {
|
|
if (max_bias > 0.0f) {
|
|
|
- const short h = iq2;
|
|
|
|
|
|
|
+ const uint32_t h = iq2;
|
|
|
|
|
|
|
|
const float base = h < n_head_log2 ? m0 : m1;
|
|
const float base = h < n_head_log2 ? m0 : m1;
|
|
|
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
|
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|