Quellcode durchsuchen

CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16 (#15131)

* CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16
Johannes Gäßler vor 5 Monaten
Ursprung
Commit
1d72c84188

+ 10 - 2
ggml/src/ggml-cuda/common.cuh

@@ -233,9 +233,13 @@ typedef float2 dfloat2;
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-#define NEW_MMA_AVAILABLE
+#define TURING_MMA_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define AMPERE_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #define CP_ASYNC_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -303,10 +307,14 @@ static bool amd_mfma_available(const int cc) {
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
-static bool new_mma_available(const int cc) {
+static bool turing_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
 
+static bool ampere_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
 static bool cp_async_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
 }

+ 6 - 6
ggml/src/ggml-cuda/fattn-mma-f16.cuh

@@ -418,7 +418,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         float        * const __restrict__ KQ_max,
         float        * const __restrict__ KQ_rowsum,
         const int kb0) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     typedef fattn_mma_f16_config<DKQ, DV> c;
 
 #ifdef CP_ASYNC_AVAILABLE
@@ -776,7 +776,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
     GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
     NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
@@ -800,7 +800,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jt,
         const int kb0_start,
         const int kb0_stop) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     typedef fattn_mma_f16_config<DKQ, DV> c;
@@ -1196,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
     GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
     NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
@@ -1223,7 +1223,7 @@ static __global__ void flash_attn_ext_f16(
                             const int32_t nb21, const int32_t nb22, const int64_t nb23,
                             const int32_t ne31, const int32_t ne32, const int32_t ne33,
                             const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1354,7 +1354,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
 }
 
 template <int DKQ, int DV, int ncols1, int ncols2>

+ 2 - 2
ggml/src/ggml-cuda/fattn.cu

@@ -327,7 +327,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
     const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
-    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
+    const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
         (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
@@ -340,7 +340,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
-    if (fp16_mma_available(cc) && !new_mma_available(cc)) {
+    if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
         return;
     }

+ 20 - 11
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -22,8 +22,9 @@
 #include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
+#include "ggml-cuda/mmf.cuh"
 #include "ggml-cuda/mmq.cuh"
-#include "ggml-cuda/mmv.cuh"
+#include "ggml-cuda/mmvf.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
@@ -2008,7 +2009,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
         && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
 
-    bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+    bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+    bool use_mul_mat_f     = !ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
@@ -2028,14 +2031,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
             }
 
             const int cc            = ggml_cuda_info().devices[id].cc;
+            const int warp_size     = ggml_cuda_info().devices[id].warp_size;
             use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            use_mul_mat_vec         = use_mul_mat_vec           && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+            use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+            use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
         }
     } else {
         const int cc            = ggml_cuda_info().devices[ctx.device].cc;
+        const int warp_size     = ggml_cuda_info().devices[ctx.device].warp_size;
         use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        use_mul_mat_vec         = use_mul_mat_vec           && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+        use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+        use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
     }
 
@@ -2048,15 +2055,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
     //TODO update for generic tensor parallelism
-    const int cc                     = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int cc                 = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     bool use_batched_cublas_f16  = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
     bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
     bool use_batched_cublas_f32  = src0->type == GGML_TYPE_F32;
 
-    if (!split && use_mul_mat_vec) {
+    if (!split && use_mul_mat_vec_f) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
-        ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
+        ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
+    } else if (!split && use_mul_mat_f) {
+        ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
@@ -2065,8 +2074,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
-    } else if (use_mul_mat_vec) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
+    } else if (use_mul_mat_vec_f) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
     } else if (use_mul_mat_vec_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
@@ -2094,7 +2103,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             if (ggml_is_quantized(src0->type)) {
                 ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
             } else {
-                ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
+                ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
             }
             return;
         }
@@ -3516,7 +3525,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-                if (!new_mma_available(cc)) {
+                if (!turing_mma_available(cc)) {
                     return false;
                 }
                 const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];

+ 88 - 22
ggml/src/ggml-cuda/mma.cuh

@@ -23,13 +23,13 @@
 static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
     int ret = 0;
 
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
         : "=r"(ret) : "r"(x));
 #else
     GGML_UNUSED(x);
     NO_DEVICE_CODE;
-#endif // defined(NEW_MMA_AVAILABLE)
+#endif // defined(TURING_MMA_AVAILABLE)
     return ret;
 }
 
@@ -167,6 +167,38 @@ namespace ggml_cuda_mma {
         }
     };
 
+    template <int I_, int J_>
+    struct tile<I_, J_, nv_bfloat162> {
+        static constexpr int I  = I_;
+        static constexpr int J  = J_;
+        static constexpr int ne = I * J / WARP_SIZE;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return l * 8 + threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l % 2) * 8 + threadIdx.x / 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return l * 4 + threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l / 2) * 4 + threadIdx.x % 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+    };
+
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
         tile<I, J/2, half2> ret;
@@ -209,7 +241,7 @@ namespace ggml_cuda_mma {
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
         asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -217,13 +249,13 @@ namespace ggml_cuda_mma {
             : "l"(xs));
 #else
         load_generic(t, xs0, stride);
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
         asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -232,13 +264,13 @@ namespace ggml_cuda_mma {
 #else
         load_generic(xs0, stride);
         GGML_UNUSED(t);
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#if defined(NEW_MMA_AVAILABLE)
+#if defined(TURING_MMA_AVAILABLE)
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -246,13 +278,13 @@ namespace ggml_cuda_mma {
             : "l"(xs));
 #else
         load_generic(t, xs0, stride);
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix_trans(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
@@ -263,12 +295,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(xs0);
         GGML_UNUSED(stride);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
             : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -287,12 +319,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
             : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -317,12 +349,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int       * Dxi = (int       *) D.x;
@@ -344,12 +376,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int       * Dxi = (int       *) D.x;
@@ -380,12 +412,29 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int       * Dxi = (int       *) D.x;
@@ -407,12 +456,29 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int       * Dxi = (int       *) D.x;
@@ -443,7 +509,7 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(

+ 431 - 0
ggml/src/ggml-cuda/mmf.cu

@@ -0,0 +1,431 @@
+#include "ggml.h"
+#include "common.cuh"
+#include "mma.cuh"
+#include "mmf.cuh"
+
+using namespace ggml_cuda_mma;
+
+#define MMF_ROWS_PER_BLOCK 32
+
+template <typename T, int rows_per_block, int cols_per_block, int nwarps>
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f(
+        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile< 8, 8, T>     tile_B;
+    typedef tile<16, 8, float> tile_C;
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int tile_k_padded = warp_size + 4;
+    constexpr int ntA = rows_per_block / tile_A::I;
+    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
+
+    const int row0        = blockIdx.x * rows_per_block;
+    const int channel_dst = blockIdx.y;
+    const int channel_x   = channel_dst / channel_ratio;
+    const int channel_y   = channel_dst;
+    const int sample_dst  = blockIdx.z;
+    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_y    = sample_dst;
+
+    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row0*stride_row ;
+    y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
+    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
+
+    const float2 * y2 = (const float2 *) y;
+
+    extern __shared__ char data_mmv[];
+
+    tile_C C[ntA][ntB];
+
+    T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
+
+    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
+        tile_A A[ntA][warp_size / tile_A::J];
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int i = 0; i < tile_A::I; ++i) {
+                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row  + col];
+            }
+#pragma unroll
+            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
+                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
+            }
+        }
+
+#pragma unroll
+        for (int itB = 0; itB < ntB; ++itB) {
+            if constexpr (std::is_same_v<T, float>) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + itB*tile_B::I;
+
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
+                }
+            } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + itB*tile_B::I;
+
+                    const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                }
+            } else {
+                static_assert(std::is_same_v<T, void>, "unsupported type");
+            }
+#pragma unroll
+            for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+                tile_B B;
+                load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+                for (int itA = 0; itA < ntA; ++itA) {
+                    mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+                }
+            }
+        }
+    }
+
+    float * buf_iw = (float *) data_mmv;
+    constexpr int kiw = nwarps*rows_per_block + 4;
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+#pragma unroll
+    for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int l = 0; l < tile_C::ne; ++l) {
+                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
+                const int j = itB*tile_C::J + tile_C::get_j(l);
+                buf_iw[j*kiw + i] = C[itA][itB].x[l];
+            }
+        }
+    }
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
+            return;
+        }
+
+        float sum = 0.0f;
+        static_assert(rows_per_block == warp_size, "need loop/check");
+#pragma unroll
+        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
+            const int i = i0 + threadIdx.x;
+
+            sum += buf_iw[j*kiw + i];
+        }
+        dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
+    }
+#else
+    NO_DEVICE_CODE;
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(ids); GGML_UNUSED(dst);
+    GGML_UNUSED(ncols); GGML_UNUSED(nchannels_y); GGML_UNUSED(stride_row); GGML_UNUSED(stride_col_y); GGML_UNUSED(stride_col_dst);
+    GGML_UNUSED(channel_ratio); GGML_UNUSED(stride_channel_x); GGML_UNUSED(stride_channel_y); GGML_UNUSED(stride_channel_dst);
+    GGML_UNUSED(sample_ratio); GGML_UNUSED(stride_sample_x); GGML_UNUSED(stride_sample_y); GGML_UNUSED(stride_sample_dst);
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+}
+
+template <typename T, int cols_per_block>
+static void mul_mat_f_cuda(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nrows_x,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile< 8, 8, T>     tile_B;
+    typedef tile<16, 8, float> tile_C;
+
+    GGML_ASSERT(!ids && "mul_mat_id not implemented");
+
+    GGML_ASSERT(ncols_x      % 2 == 0);
+    GGML_ASSERT(stride_row   % 2 == 0);
+    GGML_ASSERT(stride_col_y % 2 == 0);
+    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
+    const int64_t channel_ratio = nchannels_dst / nchannels_x;
+    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
+
+    const int device = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+
+    int64_t nwarps_best     = 1;
+    int64_t niter_best      = (ncols_x + warp_size*2 - 1) / (warp_size*2);
+    int64_t max_block_size  = 256;
+    for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
+        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
+        if (niter < niter_best) {
+            niter_best  = niter;
+            nwarps_best = nwarps;
+        }
+    }
+
+    constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
+    const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
+    const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
+    const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
+    const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
+    const dim3 block_dims(warp_size, nwarps_best, 1);
+    switch (nwarps_best) {
+        case 1: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 2: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 3: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 4: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 5: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 6: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 7: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 8: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+template <typename T>
+static void mul_mat_f_switch_cols_per_block(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    switch (ncols_dst) {
+        case  1: {
+            mul_mat_f_cuda<T,  1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  2: {
+            mul_mat_f_cuda<T,  2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  3: {
+            mul_mat_f_cuda<T,  3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  4: {
+            mul_mat_f_cuda<T,  4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  5: {
+            mul_mat_f_cuda<T,  5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  6: {
+            mul_mat_f_cuda<T,  6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  7: {
+            mul_mat_f_cuda<T,  7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  8: {
+            mul_mat_f_cuda<T,  8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  9: {
+            mul_mat_f_cuda<T,  9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 10: {
+            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 11: {
+            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 12: {
+            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 13: {
+            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 14: {
+            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 15: {
+            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 16: {
+            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(         dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
+
+    GGML_ASSERT(ne13 == ne3);
+
+    GGML_ASSERT(        nb00       == ts_src0);
+    GGML_ASSERT(        nb10       == ts_src1);
+    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+    GGML_ASSERT(        nb0        == ts_dst);
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+    const float   * src1_d =       (const float   *) src1->data;
+    const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
+    float         *  dst_d =       (float         *)  dst->data;
+
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s11 = src1->nb[1] / ts_src1;
+    const int64_t s1  =  dst->nb[1] / ts_dst;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s12 = src1->nb[2] / ts_src1;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s13 = src1->nb[3] / ts_src1;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
+
+    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
+    const int64_t ncols_dst          = ids ? ne2  : ne1;
+    const int64_t nchannels_y        = ids ? ne11 : ne12;
+    const int64_t nchannels_dst      = ids ? ne1  : ne2;
+    const int64_t stride_channel_dst = ids ? s1   : s2;
+    const int64_t stride_channel_y   = ids ? s11  : s12;
+
+    GGML_ASSERT(!ids || ncols_dst == 1);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0->data;
+            constexpr int vals_per_T = 1;
+            mul_mat_f_switch_cols_per_block(
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
+                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03/vals_per_T, s13,              s3,                 ctx.stream());
+        } break;
+        case GGML_TYPE_F16: {
+            const half2 * src0_d = (const half2 *) src0->data;
+            constexpr int vals_per_T = 2;
+            mul_mat_f_switch_cols_per_block(
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
+                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03/vals_per_T, s13,              s3,                 ctx.stream());
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
+            constexpr int vals_per_T = 2;
+            mul_mat_f_switch_cols_per_block(
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
+                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03/vals_per_T, s13,              s3,                 ctx.stream());
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
+}
+
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) {
+    if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
+        return false;
+    }
+    if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
+        return false;
+    }
+    if (ne11 > 16) {
+        return false;
+    }
+    switch (type) {
+        case GGML_TYPE_F32:
+            return ampere_mma_available(cc);
+        case GGML_TYPE_F16:
+            return turing_mma_available(cc);
+        case GGML_TYPE_BF16:
+            return ampere_mma_available(cc);
+        default:
+            return false;
+    }
+}

+ 5 - 0
ggml/src/ggml-cuda/mmf.cuh

@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11);

+ 1 - 1
ggml/src/ggml-cuda/mmq.cu

@@ -310,7 +310,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }
 
-    if (new_mma_available(cc)) {
+    if (turing_mma_available(cc)) {
         return true;
     }
 

Datei-Diff unterdrückt, da er zu groß ist
+ 131 - 131
ggml/src/ggml-cuda/mmq.cuh


+ 49 - 45
ggml/src/ggml-cuda/mmv.cu → ggml/src/ggml-cuda/mmvf.cu

@@ -1,9 +1,9 @@
 #include "ggml.h"
 #include "common.cuh"
-#include "mmv.cuh"
+#include "mmvf.cuh"
 
 template <typename T, typename type_acc, int ncols_dst, int block_size>
-static __global__ void mul_mat_vec(
+static __global__ void mul_mat_vec_f(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
         const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
@@ -37,7 +37,7 @@ static __global__ void mul_mat_vec(
 
     float sumf[ncols_dst] = {0.0f};
 
-    if constexpr (std::is_same<T, float>::value) {
+    if constexpr (std::is_same_v<T, float>) {
         const float2 * x2 = (const float2 *) x;
 
         for (int col2 = tid; col2 < ncols2; col2 += block_size) {
@@ -50,10 +50,10 @@ static __global__ void mul_mat_vec(
                 sumf[j] += tmpx.y*tmpy.y;
             }
         }
-    } else if constexpr (std::is_same<T, half>::value) {
+    } else if constexpr (std::is_same_v<T, half>) {
         const half2 * x2 = (const half2 *) x;
 
-        if (std::is_same<type_acc, float>::value) {
+        if (std::is_same_v<type_acc, float>) {
             for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                 const float2 tmpx = __half22float2(x2[col2]);
 
@@ -86,7 +86,7 @@ static __global__ void mul_mat_vec(
             NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
         }
-    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
+    } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
         const int * x2 = (const int *) x;
         for (int col2 = tid; col2 < ncols2; col2 += block_size) {
             const int tmpx = x2[col2];
@@ -98,7 +98,7 @@ static __global__ void mul_mat_vec(
             }
         }
     } else {
-        static_assert(std::is_same<T, void>::value, "unsupported type");
+        static_assert(std::is_same_v<T, void>, "unsupported type");
     }
 
 #pragma unroll
@@ -126,7 +126,7 @@ static __global__ void mul_mat_vec(
 }
 
 template <typename T, typename type_acc, int ncols_dst>
-static void launch_mul_mat_vec_cuda(
+static void launch_mul_mat_vec_f_cuda(
         const T * x, const float * y, const int32_t * ids, float * dst,
         const int64_t ncols, const int64_t nrows,
         const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
@@ -141,11 +141,9 @@ static void launch_mul_mat_vec_cuda(
     GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
     const int64_t channel_ratio = nchannels_dst / nchannels_x;
     const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
-    int device;
-    int warp_size;
 
-    CUDA_CHECK(cudaGetDevice(&device));
-    warp_size = ggml_cuda_info().devices[device].warp_size;
+    const int device = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
 
     int64_t block_size_best = warp_size;
     int64_t niter_best      = (ncols + 2*warp_size - 1) / (2*warp_size);
@@ -161,54 +159,54 @@ static void launch_mul_mat_vec_cuda(
         }
     }
 
-    const int smem = warp_size*sizeof(float);
+    const int nbytes_shared = warp_size*sizeof(float);
     const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case   32: {
-            mul_mat_vec<T, type_acc, ncols_dst,  32><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec_f<T, type_acc, ncols_dst,  32><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   64: {
-            mul_mat_vec<T, type_acc, ncols_dst,  64><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec_f<T, type_acc, ncols_dst,  64><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   96: {
-            mul_mat_vec<T, type_acc, ncols_dst,  96><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec_f<T, type_acc, ncols_dst,  96><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  128: {
-            mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  160: {
-            mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  192: {
-            mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  224: {
-            mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  256: {
-            mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
@@ -220,7 +218,7 @@ static void launch_mul_mat_vec_cuda(
 }
 
 template <typename T, typename type_acc>
-static void mul_mat_vec_cuda_switch_ncols_dst(
+static void mul_mat_vec_f_cuda_switch_ncols_dst(
         const T * x, const float * y, const int32_t * ids, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
         const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
@@ -230,49 +228,49 @@ static void mul_mat_vec_cuda_switch_ncols_dst(
         cudaStream_t stream) {
     switch (ncols_dst) {
         case 1:
-            launch_mul_mat_vec_cuda<T, type_acc, 1>
+            launch_mul_mat_vec_f_cuda<T, type_acc, 1>
                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case 2:
-            launch_mul_mat_vec_cuda<T, type_acc, 2>
+            launch_mul_mat_vec_f_cuda<T, type_acc, 2>
                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case 3:
-            launch_mul_mat_vec_cuda<T, type_acc, 3>
+            launch_mul_mat_vec_f_cuda<T, type_acc, 3>
                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case 4:
-            launch_mul_mat_vec_cuda<T, type_acc, 4>
+            launch_mul_mat_vec_f_cuda<T, type_acc, 4>
                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case 5:
-            launch_mul_mat_vec_cuda<T, type_acc, 5>
+            launch_mul_mat_vec_f_cuda<T, type_acc, 5>
                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case 6:
-            launch_mul_mat_vec_cuda<T, type_acc, 6>
+            launch_mul_mat_vec_f_cuda<T, type_acc, 6>
                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case 7:
-            launch_mul_mat_vec_cuda<T, type_acc, 7>
+            launch_mul_mat_vec_f_cuda<T, type_acc, 7>
                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case 8:
-            launch_mul_mat_vec_cuda<T, type_acc, 8>
+            launch_mul_mat_vec_f_cuda<T, type_acc, 8>
                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
@@ -284,7 +282,7 @@ static void mul_mat_vec_cuda_switch_ncols_dst(
 }
 
 template<typename T>
-static void mul_mat_vec_cuda(
+static void mul_mat_vec_f_cuda(
         const T * x, const float * y, const int32_t * ids, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
         const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
@@ -292,22 +290,22 @@ static void mul_mat_vec_cuda(
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         enum ggml_prec prec, cudaStream_t stream) {
-    if constexpr(std::is_same<T, half>::value) {
+    if constexpr(std::is_same_v<T, half>) {
         if (prec == GGML_PREC_DEFAULT) {
-            mul_mat_vec_cuda_switch_ncols_dst<T, half>
+            mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
                 (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             return;
         }
     }
-    mul_mat_vec_cuda_switch_ncols_dst<T, float>
+    mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
         (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
          nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
          stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
 }
 
-void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
     GGML_ASSERT(        src1->type == GGML_TYPE_F32);
     GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
     GGML_ASSERT(         dst->type == GGML_TYPE_F32);
@@ -355,19 +353,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
         } break;
@@ -376,7 +374,7 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     }
 }
 
-void ggml_cuda_op_mul_mat_vec(
+void ggml_cuda_op_mul_mat_vec_f(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
@@ -414,19 +412,19 @@ void ggml_cuda_op_mul_mat_vec(
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
@@ -442,15 +440,15 @@ void ggml_cuda_op_mul_mat_vec(
     GGML_UNUSED(src1_padded_row_size);
 }
 
-bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
     if (src0_ne[0] % 2 != 0) {
         return false;
     }
     switch (type) {
         case GGML_TYPE_F32:
             if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
-                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
-                    return ne11 <= 8;
+                if (ampere_mma_available(cc)) {
+                    return ne11 <= 3;
                 }
                 if (cc >= GGML_CUDA_CC_TURING) {
                     return ne11 <= 4;
@@ -466,6 +464,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
         case GGML_TYPE_F16:
             if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
                 const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (ampere_mma_available(cc)) {
+                    return src0_small && ne11 == 1;
+                }
                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
                     return src0_small && ne11 <= 4;
                 }
@@ -486,6 +487,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
         case GGML_TYPE_BF16:
             if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
                 const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (ampere_mma_available(cc)) {
+                    return src0_small && ne11 == 1;
+                }
                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
                     return src0_small && ne11 <= 4;
                 }

+ 3 - 3
ggml/src/ggml-cuda/mmv.cuh → ggml/src/ggml-cuda/mmvf.cuh

@@ -1,11 +1,11 @@
 #include "common.cuh"
 
-void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
 
-void ggml_cuda_op_mul_mat_vec(
+void ggml_cuda_op_mul_mat_vec_f(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
 
-bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);

+ 1 - 0
ggml/src/ggml-cuda/vendors/hip.h

@@ -200,6 +200,7 @@
 #endif
 
 typedef hip_bfloat16 nv_bfloat16;
+typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix
 
 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
 typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));

+ 2 - 1
ggml/src/ggml-cuda/vendors/musa.h

@@ -137,4 +137,5 @@
 #define cudaStreamEndCapture musaStreamEndCapture
 #define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
 
-typedef mt_bfloat16 nv_bfloat16;
+typedef __mt_bfloat16 nv_bfloat16;
+typedef __mt_bfloat162 nv_bfloat162;

Einige Dateien werden nicht angezeigt, da zu viele Dateien in diesem Diff geändert wurden.