|
|
@@ -152,14 +152,17 @@ void main() {
|
|
|
uint32_t d = (idx + tid) % (HSK / 4);
|
|
|
uint32_t c = (idx + tid) / (HSK / 4);
|
|
|
if (c < Bc && d < HSK / 4) {
|
|
|
+ f16vec4 K_Tf = f16vec4(0);
|
|
|
+ if (!KV_bounds_check || j * Bc + c < KV) {
|
|
|
#if BLOCK_SIZE > 1
|
|
|
- uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
|
|
- uint ib = coord / BLOCK_SIZE;
|
|
|
- uint iqs = (coord % BLOCK_SIZE);
|
|
|
- f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
|
|
|
+ uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
|
|
+ uint ib = coord / BLOCK_SIZE;
|
|
|
+ uint iqs = (coord % BLOCK_SIZE);
|
|
|
+ K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
|
|
|
#else
|
|
|
- f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
|
|
+ K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
|
|
#endif
|
|
|
+ }
|
|
|
|
|
|
ksh[c * kshstride + d] = K_Tf;
|
|
|
}
|
|
|
@@ -202,7 +205,9 @@ void main() {
|
|
|
uint32_t c = (idx + tid) % Bc;
|
|
|
uint32_t r = (idx + tid) / Bc;
|
|
|
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
|
|
- sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
|
|
|
+ if (!KV_bounds_check || j * Bc + c < KV) {
|
|
|
+ sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
barrier();
|
|
|
@@ -210,8 +215,11 @@ void main() {
|
|
|
|
|
|
float eMf[rows_per_thread];
|
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
|
- float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride];
|
|
|
+ float rowmaxf = NEG_FLT_MAX_OVER_2;
|
|
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
|
+ if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
|
|
|
}
|
|
|
float Moldf = Mf[r];
|
|
|
@@ -233,6 +241,9 @@ void main() {
|
|
|
}
|
|
|
|
|
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
|
+ if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
float Pf[rows_per_thread];
|
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
|
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
|