Bläddra i källkod

Refactor validation and enumeration platform checks into functions to clean up ggml_vk_instance_init()

0cc4m 1 år sedan
förälder
incheckning
bb9dcd560a
1 ändrade filer med 62 tillägg och 37 borttagningar
  1. 62 37
      ggml-vulkan.cpp

+ 62 - 37
ggml-vulkan.cpp

@@ -1091,7 +1091,10 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     }
 }
 
-static void ggml_vk_instance_init() {
+static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
+static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
+
+void ggml_vk_instance_init() {
     if (vk_instance_initialized) {
         return;
     }
@@ -1102,54 +1105,40 @@ static void ggml_vk_instance_init() {
     vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
 
     const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
-#ifdef __APPLE__
-    bool portability_enumeration_ext = false;
-    // Check for portability enumeration extension for MoltenVK support
-    for (const auto& properties : instance_extensions) {
-        if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
-            portability_enumeration_ext = true;
-            break;
-        }
+    const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
+    const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
+
+    std::vector<const char*> layers;
+
+    if (validation_ext) {
+        layers.push_back("VK_LAYER_KHRONOS_validation");
     }
-    if (!portability_enumeration_ext) {
-        std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
+    std::vector<const char*> extensions;
+    if (validation_ext) {
+        extensions.push_back("VK_EXT_validation_features");
     }
-#endif
-
-    std::vector<const char*> layers = {
-#ifdef GGML_VULKAN_VALIDATE
-        "VK_LAYER_KHRONOS_validation",
-#endif
-    };
-    std::vector<const char*> extensions = {
-#ifdef GGML_VULKAN_VALIDATE
-        "VK_EXT_validation_features",
-#endif
-    };
-#ifdef __APPLE__
     if (portability_enumeration_ext) {
         extensions.push_back("VK_KHR_portability_enumeration");
     }
-#endif
     vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
-#ifdef __APPLE__
     if (portability_enumeration_ext) {
         instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
     }
-#endif
 
+    std::vector<vk::ValidationFeatureEnableEXT> features_enable;
+    vk::ValidationFeaturesEXT validation_features;
 
-#ifdef GGML_VULKAN_VALIDATE
-    const std::vector<vk::ValidationFeatureEnableEXT> features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
-    vk::ValidationFeaturesEXT validation_features = {
-        features_enable,
-        {},
-    };
-    validation_features.setPNext(nullptr);
-    instance_create_info.setPNext(&validation_features);
+    if (validation_ext) {
+        features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
+        validation_features = {
+            features_enable,
+            {},
+        };
+        validation_features.setPNext(nullptr);
+        instance_create_info.setPNext(&validation_features);
 
-    std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
-#endif
+        std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
+    }
     vk_instance.instance = vk::createInstance(instance_create_info);
 
     memset(vk_instance.initialized, 0, sizeof(bool) * GGML_VK_MAX_DEVICES);
@@ -5329,6 +5318,42 @@ GGML_CALL int ggml_backend_vk_reg_devices() {
     return vk_instance.device_indices.size();
 }
 
+// Extension availability
+static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
+#ifdef GGML_VULKAN_VALIDATE
+    bool portability_enumeration_ext = false;
+    // Check for portability enumeration extension for MoltenVK support
+    for (const auto& properties : instance_extensions) {
+        if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
+            return true;
+        }
+    }
+    if (!portability_enumeration_ext) {
+        std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
+    }
+#endif
+    return false;
+
+    UNUSED(instance_extensions);
+}
+static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
+#ifdef __APPLE__
+    bool portability_enumeration_ext = false;
+    // Check for portability enumeration extension for MoltenVK support
+    for (const auto& properties : instance_extensions) {
+        if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
+            return true;
+        }
+    }
+    if (!portability_enumeration_ext) {
+        std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
+    }
+#endif
+    return false;
+
+    UNUSED(instance_extensions);
+}
+
 // checks
 
 #ifdef GGML_VULKAN_CHECK_RESULTS