@@ -2976,19 +2976,19 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV:
             return true;
-        case GGML_OP_FLASH_ATTN_EXT:
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-            return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
-#else
+        case GGML_OP_FLASH_ATTN_EXT: {
+            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
             if (op->src[0]->ne[0] == 128) {
                 return true;
             }
-            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
-            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
-                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
+            return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+        }
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
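For readability, here is a minimal standalone sketch (not part of the patch) of the support check the reworked `GGML_OP_FLASH_ATTN_EXT` case implements. The helper name `flash_attn_ext_supported` and its parameters are hypothetical stand-ins for `op->src[0]->ne[0]` (head size), whether the K and V tensors are `GGML_TYPE_F16`, and the device compute capability; the `CC_VOLTA`/`CC_OFFSET_AMD` values mirror the defines in ggml's CUDA common header at the time of this change.

```cpp
// Values as defined in ggml-cuda's common header at the time of this change.
#define CC_VOLTA      700
#define CC_OFFSET_AMD 1000000

// Hypothetical helper: mirrors the logic of the GGML_OP_FLASH_ATTN_EXT case above.
static bool flash_attn_ext_supported(long head_size, bool k_is_f16, bool v_is_f16, int cc) {
    if (head_size == 64 && k_is_f16) {
        return true;   // head size 64 requires an F16 K tensor
    }
    if (head_size == 128) {
        return true;   // head size 128 is supported regardless of K/V type
    }
    if (head_size == 256 && k_is_f16 && v_is_f16) {
        return true;   // head size 256 requires F16 K and V
    }
    // Remaining head sizes: NVIDIA Volta or newer only
    // (cc < CC_OFFSET_AMD excludes HIP/AMD devices), and only with F16 K and V.
    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && k_is_f16 && v_is_f16;
}
```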