Просмотр исходного кода

CUDA: fix allignment on register spill for FA (#18815)

Johannes Gäßler 1 неделя назад
Родитель
Сommit
5c662d21a3

+ 2 - 2
ggml/src/ggml-cuda/fattn-common.cuh

@@ -59,7 +59,7 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
 
 #pragma unroll
     for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
-        half2 tmp[cpy_ne];
+        __align__(16) half2 tmp[cpy_ne];
         ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
 #pragma unroll
         for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
@@ -309,7 +309,7 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_
         ggml_cuda_memcpy_1<ne*sizeof(half)>(dst, (const half *) vx + i0);
     } else if constexpr (std::is_same_v<T, float>) {
         static_assert(ne % 2 == 0, "bad ne");
-        half2 tmp[ne/2];
+        __align__(16) half2 tmp[ne/2];
         ggml_cuda_memcpy_1<ne*sizeof(half)>(tmp, (const half *) vx + i0);
         float2 * dst_f2 = (float2 *) dst;
 #pragma unroll

+ 21 - 21
ggml/src/ggml-cuda/fattn-tile.cuh

@@ -343,7 +343,7 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
                 for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
                     const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;
 
-                    const half2 zero[cpy_ne] = {{0.0f, 0.0f}};
+                    const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};
                     ggml_cuda_memcpy_1<cpy_nb>(
                         tile_KV + i*(J/2 + J_padding) + j,
                         !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
@@ -394,11 +394,11 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
                     const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2);
 
                     const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};
-                    half2 tmp_h2[cpy_ne/2];
+                    __align__(16) half2 tmp_h2[cpy_ne/2];
                     ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
                         tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
 
-                    float2 tmp_f2[cpy_ne/2];
+                    __align__(16) float2 tmp_f2[cpy_ne/2];
 #pragma unroll
                     for (int l = 0; l < cpy_ne/2; ++l) {
                         tmp_f2[l] = __half22float2(tmp_h2[l]);
@@ -445,14 +445,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
     static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
 #pragma unroll
     for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
-        half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
-        half2 Q_k[cpw][cpy_ne];
+        __align__(16) half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+        __align__(16) half2 Q_k[cpw][cpy_ne];
 #else
     static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
 #pragma unroll
     for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
-        float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
-        float Q_k[cpw][cpy_ne];
+        __align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+        __align__(16) float Q_k[cpw][cpy_ne];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
@@ -602,9 +602,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
 #pragma unroll
     for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
 #ifdef FAST_FP16_AVAILABLE
-        half  tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+        __align__(16) half  tmp[nbatch_fa/(np*warp_size)][KQ_cs];
 #else
-        float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+        __align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
@@ -664,8 +664,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
 #ifdef FAST_FP16_AVAILABLE
 #pragma unroll
         for (int k1 = 0; k1 < nbatch_V; k1 += np) {
-            half2 V_k[(DVp/2)/warp_size];
-            half2 KQ_k[cpw];
+            __align__(16) half2 V_k[(DVp/2)/warp_size];
+            __align__(16) half2 KQ_k[cpw];
 
             constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
 #pragma unroll
@@ -676,7 +676,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
             for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
                 const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);
 
-                half tmp[KQ_cs];
+                __align__(16) half tmp[KQ_cs];
                 ggml_cuda_memcpy_1<KQ_cs*sizeof(half)>(
                     &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
 #pragma unroll
@@ -696,8 +696,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
 #else
 #pragma unroll
         for (int k1 = 0; k1 < nbatch_V; k1 += np) {
-            float2 V_k[(DVp/2)/warp_size];
-            float  KQ_k[cpw];
+            __align__(16) float2 V_k[(DVp/2)/warp_size];
+            __align__(16) float  KQ_k[cpw];
 
             constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
 #pragma unroll
@@ -821,12 +821,12 @@ static __global__ void flash_attn_tile(
     __shared__ half2 Q_tmp[ncols * DKQ/2];
     __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV];
     __shared__ half  KQ[ncols * nbatch_fa];
-    half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+    __align__(16) half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
 #else
     __shared__ float Q_tmp[ncols * DKQ];
     __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV];
     __shared__ float KQ[ncols * nbatch_fa];
-    float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+    __align__(16) float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
 #endif // FAST_FP16_AVAILABLE
 
     float KQ_max[cpw];
@@ -849,7 +849,7 @@ static __global__ void flash_attn_tile(
 #pragma unroll
         for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
             if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
-                float tmp_f[cpy_ne_D] = {0.0f};
+                __align__(16) float tmp_f[cpy_ne_D] = {0.0f};
                 ggml_cuda_memcpy_1<sizeof(tmp_f)>
                     (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))
                                  + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
@@ -860,7 +860,7 @@ static __global__ void flash_attn_tile(
                 }
 
 #ifdef FAST_FP16_AVAILABLE
-                half2 tmp_h2[cpy_ne_D/2];
+                __align__(16) half2 tmp_h2[cpy_ne_D/2];
 #pragma unroll
                 for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
                     tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
@@ -959,7 +959,7 @@ static __global__ void flash_attn_tile(
             constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
 #pragma unroll
             for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
-                half2 tmp[cpy_ne_D];
+                __align__(16) half2 tmp[cpy_ne_D];
                 ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]);
 #pragma unroll
                 for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -970,7 +970,7 @@ static __global__ void flash_attn_tile(
             constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
 #pragma unroll
             for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
-                float tmp[cpy_ne_D];
+                __align__(16) float tmp[cpy_ne_D];
                 ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]);
 #pragma unroll
                 for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -1033,7 +1033,7 @@ static __global__ void flash_attn_tile(
         constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
 #pragma unroll
         for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
-            float2 tmp[cpy_ne_D];
+            __align__(16) float2 tmp[cpy_ne_D];
 #pragma unroll
             for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                 tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);

+ 2 - 2
ggml/src/ggml-cuda/fattn-vec.cuh

@@ -132,7 +132,7 @@ static __global__ void flash_attn_ext_vec(
 #ifdef V_DOT2_F32_F16_AVAILABLE
     half2  Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
 #else
-    float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
+    __align__(16) float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
 #endif // V_DOT2_F32_F16_AVAILABLE
     int    Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
     float2  Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
@@ -200,7 +200,7 @@ static __global__ void flash_attn_ext_vec(
             for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
                 const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
 
-                float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
+                __align__(16) float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
                 if (ncols == 1 || ic0 + j < int(ne01.z)) {
                     ggml_cuda_memcpy_1<cpy_nb>(tmp,            &Q_j[i]);
                     ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);