فهرست منبع

Add check for VK_KHR_portability_enumeration for MoltenVK support

0cc4m 1 سال پیش
والد
کامیت
f50db6ae0b
1فایلهای تغییر یافته به همراه30 افزوده شده و 9 حذف شده
  1. 30 9
      ggml-vulkan.cpp

+ 30 - 9
ggml-vulkan.cpp

@@ -1100,24 +1100,45 @@ static void ggml_vk_instance_init() {
 #endif
 #endif
 
 
     vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
     vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
-    const std::vector<const char*> layers = {
+
+    const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
+#ifdef __APPLE__
+    bool portability_enumeration_ext = false;
+    // Check for portability enumeration extension for MoltenVK support
+    for (const auto& properties : instance_extensions) {
+        if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
+            portability_enumeration_ext = true;
+            break;
+        }
+    }
+    if (!portability_enumeration_ext) {
+        std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
+    }
+#endif
+
+    std::vector<const char*> layers = {
 #ifdef GGML_VULKAN_VALIDATE
 #ifdef GGML_VULKAN_VALIDATE
         "VK_LAYER_KHRONOS_validation",
         "VK_LAYER_KHRONOS_validation",
 #endif
 #endif
     };
     };
-    const std::vector<const char*> extensions = {
+    std::vector<const char*> extensions = {
 #ifdef GGML_VULKAN_VALIDATE
 #ifdef GGML_VULKAN_VALIDATE
         "VK_EXT_validation_features",
         "VK_EXT_validation_features",
 #endif
 #endif
+    };
 #ifdef __APPLE__
 #ifdef __APPLE__
-        "VK_KHR_portability_enumeration",
+    if (portability_enumeration_ext) {
+        extensions.push_back("VK_KHR_portability_enumeration");
+    }
 #endif
 #endif
-    };
-    vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags(), &app_info, layers, extensions);
+    vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
 #ifdef __APPLE__
 #ifdef __APPLE__
-    instance_create_info.flags = vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
+    if (portability_enumeration_ext) {
+        instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
+    }
 #endif
 #endif
 
 
+
 #ifdef GGML_VULKAN_VALIDATE
 #ifdef GGML_VULKAN_VALIDATE
     const std::vector<vk::ValidationFeatureEnableEXT> features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
     const std::vector<vk::ValidationFeatureEnableEXT> features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
     vk::ValidationFeaturesEXT validation_features = {
     vk::ValidationFeaturesEXT validation_features = {
@@ -1175,12 +1196,12 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
     vk_instance.devices[idx] = std::make_shared<vk_device>();
     vk_instance.devices[idx] = std::make_shared<vk_device>();
     ctx->device = vk_instance.devices[idx];
     ctx->device = vk_instance.devices[idx];
     ctx->device.lock()->physical_device = devices[dev_num];
     ctx->device.lock()->physical_device = devices[dev_num];
-    std::vector<vk::ExtensionProperties> ext_props = ctx->device.lock()->physical_device.enumerateDeviceExtensionProperties();
+    const std::vector<vk::ExtensionProperties> ext_props = ctx->device.lock()->physical_device.enumerateDeviceExtensionProperties();
 
 
     bool maintenance4_support = false;
     bool maintenance4_support = false;
 
 
     // Check if maintenance4 is supported
     // Check if maintenance4 is supported
-    for (auto properties : ext_props) {
+    for (const auto& properties : ext_props) {
         if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
         if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
             maintenance4_support = true;
             maintenance4_support = true;
         }
         }
@@ -1211,7 +1232,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
     bool fp16_storage = false;
     bool fp16_storage = false;
     bool fp16_compute = false;
     bool fp16_compute = false;
 
 
-    for (auto properties : ext_props) {
+    for (const auto& properties : ext_props) {
         if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
         if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
             fp16_storage = true;
             fp16_storage = true;
         } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
         } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {