|
|
@@ -437,18 +437,27 @@ namespace ggml_cuda_mma {
|
|
|
xi[0] = xs[0];
|
|
|
}
|
|
|
#elif defined(AMD_WMMA_AVAILABLE)
|
|
|
- if constexpr (I == 16 && J == 4) {
|
|
|
- int64_t * xi = (int64_t *) t.x;
|
|
|
- const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
|
|
- xi[0] = xs[0];
|
|
|
- }else if constexpr (I == 16 && J == 8) {
|
|
|
- int64_t * xi = (int64_t *) t.x;
|
|
|
- const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
|
|
|
- xi[0] = xs[0];
|
|
|
+ if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
|
|
+ ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
|
|
+
|
|
|
+ } else if constexpr (std::is_same_v<T, int>) {
|
|
|
+ if constexpr (I == 16 && J == 4) {
|
|
|
+ int64_t * xi = (int64_t *) t.x;
|
|
|
+ const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
|
|
+ xi[0] = xs[0];
|
|
|
|
|
|
- const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
|
|
|
- xi[1] = xs1[0];
|
|
|
- }else{
|
|
|
+ }else if constexpr (I == 16 && J == 8) {
|
|
|
+ int64_t * xi = (int64_t *) t.x;
|
|
|
+ const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
|
|
|
+ xi[0] = xs[0];
|
|
|
+
|
|
|
+ const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
|
|
|
+ xi[1] = xs1[0];
|
|
|
+
|
|
|
+ }else{
|
|
|
+ NO_DEVICE_CODE;
|
|
|
+ }
|
|
|
+ } else {
|
|
|
NO_DEVICE_CODE;
|
|
|
}
|
|
|
#else
|