@@ -136,7 +136,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return res;
 #else

-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS)
     cudaError_t err;
     if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
     {
@@ -149,7 +149,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return err;
 #else
     return cudaMalloc(ptr, size);
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIPBLAS)

 #endif
 }
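The two hunks above drop MUSA from the HIPBLAS exclusion, so MUSA builds now take the same allocation path as native CUDA and honor the GGML_CUDA_ENABLE_UNIFIED_MEMORY environment variable (in a MUSA build the cuda* calls are macro-mapped to their MUSA equivalents). A minimal sketch of the resulting policy, using a hypothetical helper name rather than the verbatim ggml code:

    // Sketch of the allocation policy after this change (hypothetical helper;
    // the real logic lives in ggml_cuda_device_malloc). With the env var set,
    // managed (unified) memory is used so buffers can be paged by the driver;
    // otherwise a plain device allocation is made.
    #include <cstdlib>
    #include <cuda_runtime.h>

    static cudaError_t device_malloc_sketch(void ** ptr, size_t size) {
        if (std::getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
            return cudaMallocManaged(ptr, size); // unified memory path
        }
        return cudaMalloc(ptr, size);            // default device allocation
    }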
@@ -2830,6 +2830,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
                     return false;
                 }
+#ifdef GGML_USE_MUSA
+                if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+                    !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+                    return false;
+                }
+#endif // GGML_USE_MUSA
                 switch (a->type) {
                     case GGML_TYPE_F32:
                     case GGML_TYPE_F16:
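The condition added here matches the batched F16 mat-mul case: b holds more than one 2-D matrix (ne[2]*ne[3] > 1), which on CUDA is normally routed to the batched GEMM path. Reporting it as unsupported under MUSA lets the scheduler fall back for those ops. A hedged restatement of the predicate as a standalone helper (hypothetical name, not part of the patch):

    #include "ggml.h"

    // Hypothetical helper restating the guard: true for the non-transposed,
    // batched (more than one 2-D matrix in b) F16 mat-mul case that this
    // change declares unsupported on MUSA.
    static bool is_batched_f16_mul_mat(const struct ggml_tensor * a,
                                       const struct ggml_tensor * b) {
        return b->type == GGML_TYPE_F16
            && b->ne[2]*b->ne[3] > 1
            && !ggml_is_transposed(a)
            && !ggml_is_transposed(b);
    }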
@@ -2853,6 +2859,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                     case GGML_TYPE_IQ3_XXS:
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
+#ifdef GGML_USE_MUSA
+                        if (a->type == GGML_TYPE_Q3_K) {
+                            return false;
+                        }
+#endif // GGML_USE_MUSA
                         return true;
                     default:
                         return false;
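With Q3_K excluded on MUSA while the other k-quants remain accepted, callers that probe the backend before scheduling get the fallback automatically. A hedged usage sketch via the public ggml-backend API (the hypothetical can_offload wrapper is illustrative only):

    #include "ggml.h"
    #include "ggml-backend.h"

    // Usage sketch: ggml_backend_sched already performs this check internally;
    // shown here only to illustrate how the per-op capability is surfaced.
    static bool can_offload(ggml_backend_t backend, const struct ggml_tensor * op) {
        return ggml_backend_supports_op(backend, op); // false => CPU fallback
    }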
@@ -2978,6 +2989,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_RWKV_WKV:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
+#ifndef FLASH_ATTN_AVAILABLE
+            return false;
+#endif
             if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
                 return true;
             }
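FLASH_ATTN_AVAILABLE is not defined in this hunk; presumably it comes from a shared CUDA header and is left undefined on MUSA targets whose architecture lacks flash-attention kernels, so the #ifndef above makes supports_op reject GGML_OP_FLASH_ATTN_EXT early there. An assumed sketch of such a guard (the arch threshold is an assumption, not the verbatim header):

    // Assumed layout of the availability guard (illustrative only): leave
    // FLASH_ATTN_AVAILABLE undefined when compiling for a MUSA arch without
    // flash-attention support, so the early return above is compiled in.
    #if !(defined(GGML_USE_MUSA) && defined(__MUSA_ARCH__) && __MUSA_ARCH__ <= 210)
    #define FLASH_ATTN_AVAILABLE
    #endif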