Browse Source

vulkan: set all memory allocations to high priority (#17624)

* vulkan: set all memory allocations to high priority

* gate by env var (GGML_VK_ENABLE_MEMORY_PRIORITY)
Jeff Bolz 1 month ago
parent
commit
93bb92664e
1 changed file with 21 additions and 1 deletion
  1. 21 1
      ggml/src/ggml-vulkan/ggml-vulkan.cpp

+ 21 - 1
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -519,6 +519,7 @@ struct vk_device_struct {
     bool fp16;
     bool fp16;
     bool bf16;
     bool bf16;
     bool pipeline_robustness;
     bool pipeline_robustness;
+    bool memory_priority;
     vk::Device device;
     vk::Device device;
     uint32_t vendor_id;
     uint32_t vendor_id;
     vk::DriverId driver_id;
     vk::DriverId driver_id;
@@ -2369,7 +2370,13 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
 
 
     vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
     vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
 
 
-    const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
+    const vk::MemoryPriorityAllocateInfoEXT mem_priority_info { 1.0f };
+
+    vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
+
+    if (device->memory_priority) {
+        mem_flags_info.setPNext(&mem_priority_info);
+    }
 
 
     for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
     for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
         const auto & req_flags = *it;
         const auto & req_flags = *it;
@@ -4340,6 +4347,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
 #endif
 #endif
             } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) {
             } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) {
                 pipeline_executable_properties_support = true;
                 pipeline_executable_properties_support = true;
+            } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
+                       getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) {
+                device->memory_priority = true;
             }
             }
         }
         }
 
 
@@ -4531,6 +4541,16 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device_extensions.push_back("VK_EXT_pipeline_robustness");
             device_extensions.push_back("VK_EXT_pipeline_robustness");
         }
         }
 
 
+        VkPhysicalDeviceMemoryPriorityFeaturesEXT memory_priority_features;
+        memory_priority_features.pNext = nullptr;
+        memory_priority_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PRIORITY_FEATURES_EXT;
+        memory_priority_features.memoryPriority = VK_FALSE;
+        if (device->memory_priority) {
+            last_struct->pNext = (VkBaseOutStructure *)&memory_priority_features;
+            last_struct = (VkBaseOutStructure *)&memory_priority_features;
+            device_extensions.push_back("VK_EXT_memory_priority");
+        }
+
         VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;
         VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;
         subgroup_size_control_features.pNext = nullptr;
         subgroup_size_control_features.pNext = nullptr;
         subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;
         subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;