|
@@ -3184,8 +3184,8 @@ kernel void kernel_flash_attn_ext(
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
|
|
{
|
|
{
|
|
|
- half S[Q] = { [0 ... Q-1] = 0.0f };
|
|
|
|
|
- half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
|
|
|
|
|
|
|
+ float S[Q] = { [0 ... Q-1] = 0.0f };
|
|
|
|
|
+ float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
|
|
|
|
|
|
|
|
// thread indices inside the simdgroup
|
|
// thread indices inside the simdgroup
|
|
|
// TODO: see if we can utilize quad-group functions for better performance
|
|
// TODO: see if we can utilize quad-group functions for better performance
|
|
@@ -3202,13 +3202,13 @@ kernel void kernel_flash_attn_ext(
|
|
|
|
|
|
|
|
const bool has_mask = mask != q;
|
|
const bool has_mask = mask != q;
|
|
|
|
|
|
|
|
- half slope = 1.0f;
|
|
|
|
|
|
|
+ float slope = 1.0f;
|
|
|
|
|
|
|
|
// ALiBi
|
|
// ALiBi
|
|
|
if (args.max_bias > 0.0f) {
|
|
if (args.max_bias > 0.0f) {
|
|
|
const short h = iq2;
|
|
const short h = iq2;
|
|
|
|
|
|
|
|
- const half base = h < args.n_head_log2 ? args.m0 : args.m1;
|
|
|
|
|
|
|
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
|
|
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
|
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
|
|
|
|
|
|
|
slope = pow(base, exph);
|
|
slope = pow(base, exph);
|
|
@@ -3224,14 +3224,14 @@ kernel void kernel_flash_attn_ext(
|
|
|
|
|
|
|
|
if (has_mask) {
|
|
if (has_mask) {
|
|
|
// used to detect blocks full of -INF
|
|
// used to detect blocks full of -INF
|
|
|
- half smax = -INFINITY;
|
|
|
|
|
|
|
+ float smax = -INFINITY;
|
|
|
|
|
|
|
|
// load the mask in shared memory
|
|
// load the mask in shared memory
|
|
|
#pragma unroll(Q)
|
|
#pragma unroll(Q)
|
|
|
for (short j = 0; j < Q; ++j) {
|
|
for (short j = 0; j < Q; ++j) {
|
|
|
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
|
|
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
|
|
|
|
|
|
|
|
- const half m = pm[ic + tiisg];
|
|
|
|
|
|
|
+ const float m = pm[ic + tiisg];
|
|
|
|
|
|
|
|
ss[j*TS + C + tiisg] = m;
|
|
ss[j*TS + C + tiisg] = m;
|
|
|
smax = max(smax, m);
|
|
smax = max(smax, m);
|
|
@@ -3327,10 +3327,10 @@ kernel void kernel_flash_attn_ext(
|
|
|
// online softmax
|
|
// online softmax
|
|
|
{
|
|
{
|
|
|
for (ushort j = 0; j < Q; ++j) {
|
|
for (ushort j = 0; j < Q; ++j) {
|
|
|
- const half m = M[j];
|
|
|
|
|
|
|
+ const float m = M[j];
|
|
|
|
|
|
|
|
// scale and apply the logitcap / mask
|
|
// scale and apply the logitcap / mask
|
|
|
- half s = ss[j*TS + tiisg]*args.scale;
|
|
|
|
|
|
|
+ float s = ss[j*TS + tiisg]*args.scale;
|
|
|
|
|
|
|
|
if (args.logit_softcap != 0.0f) {
|
|
if (args.logit_softcap != 0.0f) {
|
|
|
s = args.logit_softcap*precise::tanh(s);
|
|
s = args.logit_softcap*precise::tanh(s);
|
|
@@ -3341,8 +3341,8 @@ kernel void kernel_flash_attn_ext(
|
|
|
|
|
|
|
|
M[j] = simd_max(max(M[j], s));
|
|
M[j] = simd_max(max(M[j], s));
|
|
|
|
|
|
|
|
- const half ms = exp(m - M[j]);
|
|
|
|
|
- const half vs = exp(s - M[j]);
|
|
|
|
|
|
|
+ const float ms = exp(m - M[j]);
|
|
|
|
|
+ const float vs = exp(s - M[j]);
|
|
|
|
|
|
|
|
S[j] = S[j]*ms + simd_sum(vs);
|
|
S[j] = S[j]*ms + simd_sum(vs);
|
|
|
|
|
|
|
@@ -3444,8 +3444,8 @@ kernel void kernel_flash_attn_ext(
|
|
|
|
|
|
|
|
// reduce the warps sequentially
|
|
// reduce the warps sequentially
|
|
|
for (ushort sg = 1; sg < nsg; ++sg) {
|
|
for (ushort sg = 1; sg < nsg; ++sg) {
|
|
|
- half S = { 0.0f };
|
|
|
|
|
- half M = { -__FLT16_MAX__/2 };
|
|
|
|
|
|
|
+ float S = { 0.0f };
|
|
|
|
|
+ float M = { -__FLT16_MAX__/2 };
|
|
|
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
|
@@ -3461,16 +3461,16 @@ kernel void kernel_flash_attn_ext(
|
|
|
// the first simdgroup accumulates the results from the other simdgroups
|
|
// the first simdgroup accumulates the results from the other simdgroups
|
|
|
if (sgitg == 0) {
|
|
if (sgitg == 0) {
|
|
|
for (short j = 0; j < Q; ++j) {
|
|
for (short j = 0; j < Q; ++j) {
|
|
|
- const half S0 = ss[j*TS + 0];
|
|
|
|
|
- const half S1 = ss[j*TS + sg*SH + 0];
|
|
|
|
|
|
|
+ const float S0 = ss[j*TS + 0];
|
|
|
|
|
+ const float S1 = ss[j*TS + sg*SH + 0];
|
|
|
|
|
|
|
|
- const half M0 = ss[j*TS + 1];
|
|
|
|
|
- const half M1 = ss[j*TS + sg*SH + 1];
|
|
|
|
|
|
|
+ const float M0 = ss[j*TS + 1];
|
|
|
|
|
+ const float M1 = ss[j*TS + sg*SH + 1];
|
|
|
|
|
|
|
|
M = max(M0, M1);
|
|
M = max(M0, M1);
|
|
|
|
|
|
|
|
- const half ms0 = exp(M0 - M);
|
|
|
|
|
- const half ms1 = exp(M1 - M);
|
|
|
|
|
|
|
+ const float ms0 = exp(M0 - M);
|
|
|
|
|
+ const float ms1 = exp(M1 - M);
|
|
|
|
|
|
|
|
S = S0*ms0 + S1*ms1;
|
|
S = S0*ms0 + S1*ms1;
|
|
|
|
|
|
|
@@ -3646,16 +3646,16 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
constexpr short DV4 = DV/4;
|
|
constexpr short DV4 = DV/4;
|
|
|
constexpr short NW = N_SIMDWIDTH;
|
|
constexpr short NW = N_SIMDWIDTH;
|
|
|
constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
|
|
constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
|
|
|
- constexpr short SH = 2*C; // shared memory per simdgroup
|
|
|
|
|
|
|
+ constexpr short SH = 4*C; // shared memory per simdgroup
|
|
|
|
|
|
|
|
const short T = DK + nsg*SH; // shared memory size per query in (half)
|
|
const short T = DK + nsg*SH; // shared memory size per query in (half)
|
|
|
|
|
|
|
|
- //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
|
|
|
|
|
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
|
|
|
|
|
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
|
|
|
|
|
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
|
|
|
|
|
- threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
|
|
|
|
|
- threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
|
|
|
|
|
|
|
+ //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
|
|
|
|
|
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
|
|
|
|
|
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
|
|
|
|
|
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
|
|
|
|
|
+ threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
|
|
|
|
|
+ threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
|
|
|
|
|
|
|
|
// store the result for all queries in local memory (the O matrix from the paper)
|
|
// store the result for all queries in local memory (the O matrix from the paper)
|
|
|
o4_t lo[DV4/NL];
|
|
o4_t lo[DV4/NL];
|
|
@@ -3684,8 +3684,8 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
|
|
{
|
|
{
|
|
|
- half S = 0.0f;
|
|
|
|
|
- half M = -__FLT16_MAX__/2;
|
|
|
|
|
|
|
+ float S = 0.0f;
|
|
|
|
|
+ float M = -__FLT16_MAX__/2;
|
|
|
|
|
|
|
|
// thread indices inside the simdgroup
|
|
// thread indices inside the simdgroup
|
|
|
const short tx = tiisg%NL;
|
|
const short tx = tiisg%NL;
|
|
@@ -3703,13 +3703,13 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
// pointer to the mask
|
|
// pointer to the mask
|
|
|
device const half * pm = (device const half *) (mask + iq1*args.nb31);
|
|
device const half * pm = (device const half *) (mask + iq1*args.nb31);
|
|
|
|
|
|
|
|
- half slope = 1.0f;
|
|
|
|
|
|
|
+ float slope = 1.0f;
|
|
|
|
|
|
|
|
// ALiBi
|
|
// ALiBi
|
|
|
if (args.max_bias > 0.0f) {
|
|
if (args.max_bias > 0.0f) {
|
|
|
const short h = iq2;
|
|
const short h = iq2;
|
|
|
|
|
|
|
|
- const half base = h < args.n_head_log2 ? args.m0 : args.m1;
|
|
|
|
|
|
|
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
|
|
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
|
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
|
|
|
|
|
|
|
slope = pow(base, exph);
|
|
slope = pow(base, exph);
|
|
@@ -3799,13 +3799,13 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
|
|
|
|
|
// online softmax
|
|
// online softmax
|
|
|
{
|
|
{
|
|
|
- const half m = M;
|
|
|
|
|
- const half s = ss[tiisg];
|
|
|
|
|
|
|
+ const float m = M;
|
|
|
|
|
+ const float s = ss[tiisg];
|
|
|
|
|
|
|
|
M = simd_max(max(M, s));
|
|
M = simd_max(max(M, s));
|
|
|
|
|
|
|
|
- const half ms = exp(m - M);
|
|
|
|
|
- const half vs = exp(s - M);
|
|
|
|
|
|
|
+ const float ms = exp(m - M);
|
|
|
|
|
+ const float vs = exp(s - M);
|
|
|
|
|
|
|
|
S = S*ms + simd_sum(vs);
|
|
S = S*ms + simd_sum(vs);
|
|
|
|
|
|
|
@@ -3836,7 +3836,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
v4_t mv;
|
|
v4_t mv;
|
|
|
deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
|
|
deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
|
|
|
|
|
|
|
|
- lo[ii/NL] += mv*ms;
|
|
|
|
|
|
|
+ lo[ii/NL] += o4_t(float4(mv)*float4(ms));
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -3907,18 +3907,18 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
// parallel reduce
|
|
// parallel reduce
|
|
|
for (short r = nsg/2; r > 0; r >>= 1) {
|
|
for (short r = nsg/2; r > 0; r >>= 1) {
|
|
|
if (sgitg < r) {
|
|
if (sgitg < r) {
|
|
|
- const half S0 = ss[ 0];
|
|
|
|
|
- const half S1 = ss[r*SH + 0];
|
|
|
|
|
|
|
+ const float S0 = ss[ 0];
|
|
|
|
|
+ const float S1 = ss[r*(SH/2) + 0];
|
|
|
|
|
|
|
|
- const half M0 = ss[ 1];
|
|
|
|
|
- const half M1 = ss[r*SH + 1];
|
|
|
|
|
|
|
+ const float M0 = ss[ 1];
|
|
|
|
|
+ const float M1 = ss[r*(SH/2) + 1];
|
|
|
|
|
|
|
|
- const half M = max(M0, M1);
|
|
|
|
|
|
|
+ const float M = max(M0, M1);
|
|
|
|
|
|
|
|
- const half ms0 = exp(M0 - M);
|
|
|
|
|
- const half ms1 = exp(M1 - M);
|
|
|
|
|
|
|
+ const float ms0 = exp(M0 - M);
|
|
|
|
|
+ const float ms1 = exp(M1 - M);
|
|
|
|
|
|
|
|
- const half S = S0*ms0 + S1*ms1;
|
|
|
|
|
|
|
+ const float S = S0*ms0 + S1*ms1;
|
|
|
|
|
|
|
|
if (tiisg == 0) {
|
|
if (tiisg == 0) {
|
|
|
ss[0] = S;
|
|
ss[0] = S;
|
|
@@ -3950,11 +3950,11 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
|
|
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
|
|
|
//
|
|
//
|
|
|
#define FA_TYPES \
|
|
#define FA_TYPES \
|
|
|
- half4, \
|
|
|
|
|
- half4, \
|
|
|
|
|
- half4, \
|
|
|
|
|
- float, \
|
|
|
|
|
- half, half4, \
|
|
|
|
|
|
|
+ half4, \
|
|
|
|
|
+ half4, \
|
|
|
|
|
+ half4, \
|
|
|
|
|
+ float, \
|
|
|
|
|
+ float, float4, \
|
|
|
half4
|
|
half4
|
|
|
|
|
|
|
|
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
|
|
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
|