|
@@ -570,6 +570,7 @@ struct vk_device_struct {
|
|
|
bool uma;
|
|
bool uma;
|
|
|
bool prefer_host_memory;
|
|
bool prefer_host_memory;
|
|
|
bool float_controls_rte_fp16;
|
|
bool float_controls_rte_fp16;
|
|
|
|
|
+ bool subgroup_basic;
|
|
|
bool subgroup_arithmetic;
|
|
bool subgroup_arithmetic;
|
|
|
bool subgroup_shuffle;
|
|
bool subgroup_shuffle;
|
|
|
bool subgroup_ballot;
|
|
bool subgroup_ballot;
|
|
@@ -4301,8 +4302,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
|
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
|
|
|
|
|
|
|
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
|
|
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
|
|
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
|
|
|
|
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
|
|
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
|
|
|
} else {
|
|
} else {
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
|
|
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
|
|
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
|
|
@@ -4638,6 +4639,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
}
|
|
}
|
|
|
device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
|
|
device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
|
|
|
|
|
|
|
|
|
|
+ device->subgroup_basic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
|
|
|
|
+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBasic);
|
|
|
device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
|
device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
|
|
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
|
|
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
|
|
|
#ifdef __APPLE__
|
|
#ifdef __APPLE__
|
|
@@ -9870,8 +9873,9 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
|
|
|
|
|
|
std::array<uint32_t, 3> elements;
|
|
std::array<uint32_t, 3> elements;
|
|
|
|
|
|
|
|
- const int splitH = 16;
|
|
|
|
|
- const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
|
|
|
|
|
|
|
+ const uint32_t d_state = src0->ne[0];
|
|
|
|
|
+ uint32_t num_subgroups = d_state / ctx->device->subgroup_size;
|
|
|
|
|
+ const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, num_subgroups);
|
|
|
const uint32_t num_workgroups_y = n_seq;
|
|
const uint32_t num_workgroups_y = n_seq;
|
|
|
elements = { num_workgroups_x, num_workgroups_y, 1 };
|
|
elements = { num_workgroups_x, num_workgroups_y, 1 };
|
|
|
|
|
|
|
@@ -14777,11 +14781,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
return false;
|
|
return false;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- const uint32_t SPLIT_H = 16;
|
|
|
|
|
|
|
+ size_t shmem_size = d_state * sizeof(float);
|
|
|
|
|
|
|
|
- size_t stateC_size = SPLIT_H * d_state * sizeof(float);
|
|
|
|
|
|
|
+ if (shmem_size > device->properties.limits.maxComputeSharedMemorySize) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
|
|
|
|
|
|
|
+ if (!device->subgroup_basic) {
|
|
|
return false;
|
|
return false;
|
|
|
}
|
|
}
|
|
|
|
|
|