|
|
@@ -408,6 +408,8 @@ struct vk_device_struct {
|
|
|
bool subgroup_ballot;
|
|
|
bool subgroup_clustered;
|
|
|
bool multi_add;
|
|
|
+ bool shader_int64;
|
|
|
+ bool buffer_device_address;
|
|
|
|
|
|
bool add_rms_fusion;
|
|
|
uint32_t partials_binding_alignment;
|
|
|
@@ -655,6 +657,7 @@ struct vk_buffer_struct {
|
|
|
vk::MemoryPropertyFlags memory_property_flags;
|
|
|
void * ptr;
|
|
|
size_t size = 0;
|
|
|
+ vk::DeviceAddress bda_addr {};
|
|
|
|
|
|
vk_device device;
|
|
|
|
|
|
@@ -987,6 +990,7 @@ struct vk_op_argsort_push_constants {
|
|
|
};
|
|
|
|
|
|
struct vk_op_im2col_push_constants {
|
|
|
+ uint64_t dst_addr;
|
|
|
uint32_t batch_offset; uint32_t offset_delta;
|
|
|
uint32_t IC;
|
|
|
uint32_t IW; uint32_t IH;
|
|
|
@@ -1000,6 +1004,7 @@ struct vk_op_im2col_push_constants {
|
|
|
};
|
|
|
|
|
|
struct vk_op_im2col_3d_push_constants {
|
|
|
+ uint64_t dst_addr;
|
|
|
uint32_t nb10;
|
|
|
uint32_t nb11;
|
|
|
uint32_t nb12;
|
|
|
@@ -2012,10 +2017,17 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
|
|
return buf;
|
|
|
}
|
|
|
|
|
|
+ vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
|
|
|
+ vk::MemoryAllocateFlags mem_flags {};
|
|
|
+ if (device->buffer_device_address) {
|
|
|
+ usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
|
|
|
+ mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
|
|
|
+ }
|
|
|
+
|
|
|
vk::BufferCreateInfo buffer_create_info{
|
|
|
vk::BufferCreateFlags(),
|
|
|
size,
|
|
|
- vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst,
|
|
|
+ usage_flags,
|
|
|
vk::SharingMode::eExclusive,
|
|
|
0,
|
|
|
nullptr,
|
|
|
@@ -2027,6 +2039,8 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
|
|
|
|
|
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
|
|
|
|
|
|
+ const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
|
|
|
+
|
|
|
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
|
|
|
const auto & req_flags = *it;
|
|
|
|
|
|
@@ -2038,7 +2052,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
|
|
buf->memory_property_flags = req_flags;
|
|
|
|
|
|
try {
|
|
|
- buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
|
|
|
+ buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
|
|
|
break;
|
|
|
} catch (const vk::SystemError& e) {
|
|
|
// loop and retry
|
|
|
@@ -2066,6 +2080,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
|
|
buf->device = device;
|
|
|
buf->size = size;
|
|
|
|
|
|
+ if (device->buffer_device_address) {
|
|
|
+ const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
|
|
|
+ buf->bda_addr = device->device.getBufferAddress(addressInfo);
|
|
|
+ }
|
|
|
+
|
|
|
#ifdef GGML_VULKAN_MEMORY_DEBUG
|
|
|
device->memory_logger->log_allocation(buf, size);
|
|
|
#endif
|
|
|
@@ -3532,14 +3551,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
|
|
|
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
|
|
|
- if (device->float_controls_rte_fp16) {
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
|
|
|
+#define IM2COL(bda) \
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
|
|
+ if (device->float_controls_rte_fp16) { \
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
|
|
+ } else { \
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
|
|
+ }
|
|
|
+ if (device->shader_int64 && device->buffer_device_address) {
|
|
|
+ IM2COL(_bda)
|
|
|
} else {
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
|
|
|
+ IM2COL()
|
|
|
}
|
|
|
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
|
|
|
@@ -4017,6 +4042,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
device->vendor_id != VK_VENDOR_ID_INTEL &&
|
|
|
getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
|
|
|
|
|
|
+ device->shader_int64 = device_features2.features.shaderInt64;
|
|
|
+ device->buffer_device_address = vk12_features.bufferDeviceAddress;
|
|
|
+
|
|
|
if (device->subgroup_size_control) {
|
|
|
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
|
|
|
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
|
|
|
@@ -8635,6 +8663,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
|
|
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
|
|
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
|
|
|
+ if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
|
|
|
+ // buffer device address path doesn't use dst buffer
|
|
|
+ d_sz = 1;
|
|
|
+ }
|
|
|
// im2col uses only src1 and dst buffers
|
|
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
|
|
} else if (op == GGML_OP_COUNT_EQUAL) {
|
|
|
@@ -9486,7 +9518,13 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
|
|
|
const uint32_t pelements = OW * KW * KH;
|
|
|
|
|
|
+ const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
|
|
+ const vk_buffer d_buf = d_buf_ctx->dev_buffer;
|
|
|
+
|
|
|
+ const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
|
|
|
+
|
|
|
ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
|
|
|
+ dst_addr,
|
|
|
batch_offset, offset_delta,
|
|
|
IC, IW, IH, OW, OH, KW, KH,
|
|
|
pelements,
|
|
|
@@ -9522,8 +9560,14 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
|
const int64_t OH = ne2;
|
|
|
const int64_t OW = ne1;
|
|
|
|
|
|
+ const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
|
|
+ const vk_buffer d_buf = d_buf_ctx->dev_buffer;
|
|
|
+
|
|
|
+ const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
|
|
|
+
|
|
|
vk_op_im2col_3d_push_constants pc {};
|
|
|
|
|
|
+ pc.dst_addr = dst_addr;
|
|
|
pc.nb10 = nb10 / ggml_type_size(src1->type);
|
|
|
pc.nb11 = nb11 / ggml_type_size(src1->type);
|
|
|
pc.nb12 = nb12 / ggml_type_size(src1->type);
|