|
|
@@ -530,8 +530,8 @@ struct vk_device_struct {
|
|
|
vk_pipeline pipeline_opt_step_sgd_f32;
|
|
|
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
|
|
|
vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
|
|
|
- vk_pipeline pipeline_conv2d_dw_whcn_f32;
|
|
|
- vk_pipeline pipeline_conv2d_dw_cwhn_f32;
|
|
|
+ vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
|
|
|
+ vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
|
|
|
|
|
|
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
|
|
vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
|
|
|
@@ -3257,6 +3257,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
|
|
|
|
|
for (auto &c : compiles) {
|
|
|
c.wait();
|
|
|
@@ -7346,6 +7348,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
} else if (ggml_is_contiguous_channels(src1)) {
|
|
|
return ctx->device->pipeline_conv2d_dw_cwhn_f32;
|
|
|
}
|
|
|
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
|
|
+ if (ggml_is_contiguous(src1)) {
|
|
|
+ return ctx->device->pipeline_conv2d_dw_whcn_f16_f32;
|
|
|
+ } else if (ggml_is_contiguous_channels(src1)) {
|
|
|
+ return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32;
|
|
|
+ }
|
|
|
}
|
|
|
return nullptr;
|
|
|
default:
|