
HIP: Patch failed testcase in WMMA-MMQ kernels for RDNA 4 (#17502)

* Patch the failing test case MUL_MAT(type_a=q4_0,type_b=f32,m=576,n=512,k=576,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) when enabling WMMA on RDNA4

* Quick cleanup of mma.cuh to add ggml_cuda_memcpy_1 back in for half2 and bfloat162 (see the sketch below)
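
For readers unfamiliar with the helper: ggml_cuda_memcpy_1<nbytes> is a compile-time-sized per-thread copy used throughout ggml-cuda. The following is only a minimal sketch of the idea; the real helper lives in ggml-cuda's common header and handles more sizes and alignment cases, so the name and body here are illustrative.

    // Hedged sketch of a compile-time-sized per-thread copy in the spirit of
    // ggml_cuda_memcpy_1 (hypothetical name; the real helper may differ).
    template <int nbytes>
    static __device__ __forceinline__ void memcpy_1_sketch(void * dst, const void * src) {
        if constexpr (nbytes == 4) {
            *(int  *) dst = *(const int  *) src;   // one 32-bit load/store
        } else if constexpr (nbytes == 8) {
            *(int2 *) dst = *(const int2 *) src;   // one 64-bit load/store
        } else if constexpr (nbytes == 16) {
            *(int4 *) dst = *(const int4 *) src;   // one 128-bit load/store
        } else {
            static_assert(nbytes == 4 || nbytes == 8 || nbytes == 16, "unsupported size");
        }
    }

In the patched half2/nv_bfloat162 path below, sizeof(t.x) picks the width at compile time, so each thread copies its whole tile fragment in a single vectorized access.
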
Jiacheng (Jason) Chen, 1 month ago
commit 3e18dba9fd
2 files changed, 21 insertions and 12 deletions
  1. ggml/src/ggml-cuda/mma.cuh (+20 -11)
  2. ggml/src/ggml-cuda/mmq.cuh (+1 -1)

+ 20 - 11
ggml/src/ggml-cuda/mma.cuh

@@ -437,18 +437,27 @@ namespace ggml_cuda_mma {
             xi[0] = xs[0];
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-        if constexpr (I == 16 && J == 4) {
-            int64_t * xi = (int64_t *) t.x;
-            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-            xi[0] = xs[0];
-        }else if constexpr (I == 16 && J == 8) {
-            int64_t * xi = (int64_t *) t.x;
-            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
-            xi[0] = xs[0];
+        if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
+
+        } else if constexpr (std::is_same_v<T, int>) {
+            if constexpr (I == 16 && J == 4) {
+                int64_t * xi = (int64_t *) t.x;
+                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
+                xi[0] = xs[0];
 
-            const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
-            xi[1] = xs1[0];
-        }else{
+            }else if constexpr (I == 16 && J == 8) {
+                int64_t * xi = (int64_t *) t.x;
+                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
+                xi[0] = xs[0];
+
+                const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
+                xi[1] = xs1[0];
+
+            }else{
+                NO_DEVICE_CODE;
+            }
+        } else {
             NO_DEVICE_CODE;
         }
 #else
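
For the int tiles, the hand-rolled 8-byte loads are kept. A hedged sketch of the I == 16, J == 8 addressing in the hunk above, pulled out as a standalone function (names are illustrative; a 32-wide wave is assumed, so rows repeat every 16 lanes):

    // Illustrative sketch of the 16x8 int load: each of 32 lanes owns one row
    // (threadIdx.x % 16) and one 4-int column group (threadIdx.x / 16), fetched
    // as two 64-bit reads. Function and parameter names are hypothetical.
    __device__ void load_tile_16x8_sketch(int64_t xi[2], const int * xs0, int stride) {
        const int row   = threadIdx.x % 16;        // t.I == 16 rows
        const int group = threadIdx.x / 16;        // which 4-int group in the row
        const int * base = xs0 + row * stride + 4 * group;
        xi[0] = *(const int64_t *) (base + 0);     // ints 0..1 of the group
        xi[1] = *(const int64_t *) (base + 2);     // ints 2..3 of the group
    }
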

+ 1 - 1
ggml/src/ggml-cuda/mmq.cuh

@@ -3701,7 +3701,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
     const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
     const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
     const size_t nbs_ids = mmq_x*sizeof(int);
-    const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+    const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
     const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
     return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
 }
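
The one-line change above matters because the shared-memory estimate has to match the tile layout the kernel actually writes: MMA-family kernels (Turing MMA, AMD MFMA, and now AMD WMMA) store one contiguous int tile, while the dp4a fallback packs quants, half2 scale factors, and scales separately. A minimal sketch of the two sizing paths, with helper and parameter names that are illustrative rather than taken from the codebase:

    #include <cstddef>
    #include <cuda_fp16.h>  // for half2

    // Hedged sketch of the tile_x sizing selected by the predicate in
    // mmq_get_nbytes_shared; names and parameters are hypothetical.
    static size_t tile_x_nbytes_sketch(bool use_mma_layout, int mmq_y, int mmq_tile_x_k,
                                       int txs_qs, int txs_dm, int txs_sc) {
        if (use_mma_layout) {
            // turing_mma / amd_mfma / amd_wmma: one contiguous int tile.
            return (size_t) mmq_y * mmq_tile_x_k * sizeof(int);
        }
        // dp4a fallback: quantized values, half2 scale factors, and scales.
        return txs_qs * sizeof(int) + txs_dm * sizeof(half2) + txs_sc * sizeof(int);
    }

Without amd_wmma_available(cc) in the predicate, RDNA4 WMMA kernels would be launched with the dp4a-sized buffer rather than the MMA tile size, which is presumably the mismatch behind the failing MUL_MAT case.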