|
@@ -218,6 +218,7 @@ struct vk_device_struct {
|
|
|
vk_pipeline pipeline_tanh_f32;
|
|
vk_pipeline pipeline_tanh_f32;
|
|
|
vk_pipeline pipeline_diag_mask_inf_f32;
|
|
vk_pipeline pipeline_diag_mask_inf_f32;
|
|
|
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
|
|
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
|
|
|
|
|
+ vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
|
|
|
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
|
|
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
|
|
|
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
|
|
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
|
|
|
vk_pipeline pipeline_argsort_f32;
|
|
vk_pipeline pipeline_argsort_f32;
|
|
@@ -388,6 +389,7 @@ struct vk_op_soft_max_push_constants {
|
|
|
float m0;
|
|
float m0;
|
|
|
float m1;
|
|
float m1;
|
|
|
uint32_t n_head_log2;
|
|
uint32_t n_head_log2;
|
|
|
|
|
+ uint32_t nrows_x;
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
struct vk_op_argsort_push_constants {
|
|
struct vk_op_argsort_push_constants {
|
|
@@ -1497,8 +1499,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
|
|
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
|
|
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
|
|
|
|
|
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
|
|
|
|
|
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
|
|
|
|
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
|
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
|
|
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
|
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
|
|
|
|
|
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
@@ -3932,10 +3936,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
|
|
|
|
|
|
|
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
|
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
|
|
- return ctx->device->pipeline_soft_max_f32;
|
|
|
|
|
|
|
+ return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
|
|
|
}
|
|
}
|
|
|
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
|
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
|
|
- return ctx->device->pipeline_soft_max_f32_f16;
|
|
|
|
|
|
|
+ return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
|
|
|
}
|
|
}
|
|
|
return nullptr;
|
|
return nullptr;
|
|
|
case GGML_OP_ROPE:
|
|
case GGML_OP_ROPE:
|
|
@@ -4581,6 +4585,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
|
scale, max_bias,
|
|
scale, max_bias,
|
|
|
m0, m1,
|
|
m0, m1,
|
|
|
n_head_log2,
|
|
n_head_log2,
|
|
|
|
|
+ nrows_x,
|
|
|
}, dryrun);
|
|
}, dryrun);
|
|
|
}
|
|
}
|
|
|
|
|
|