|
@@ -3738,6 +3738,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
|
|
|
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
|
|
|
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
@@ -8294,6 +8295,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
switch (op) {
|
|
switch (op) {
|
|
|
case GGML_OP_GET_ROWS:
|
|
case GGML_OP_GET_ROWS:
|
|
|
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
|
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
|
|
|
|
+ if (src0->type == GGML_TYPE_I32) {
|
|
|
|
|
+ // i32 src only supports i32 result
|
|
|
|
|
+ GGML_ASSERT(dst->type == GGML_TYPE_I32);
|
|
|
|
|
+ return ctx->device->pipeline_get_rows[src0->type];
|
|
|
|
|
+ }
|
|
|
if (dst->type == GGML_TYPE_F16) {
|
|
if (dst->type == GGML_TYPE_F16) {
|
|
|
return ctx->device->pipeline_get_rows[src0->type];
|
|
return ctx->device->pipeline_get_rows[src0->type];
|
|
|
}
|
|
}
|
|
@@ -13964,6 +13970,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
case GGML_TYPE_IQ4_XS:
|
|
case GGML_TYPE_IQ4_XS:
|
|
|
case GGML_TYPE_IQ4_NL:
|
|
case GGML_TYPE_IQ4_NL:
|
|
|
case GGML_TYPE_MXFP4:
|
|
case GGML_TYPE_MXFP4:
|
|
|
|
|
+ case GGML_TYPE_I32:
|
|
|
return true;
|
|
return true;
|
|
|
default:
|
|
default:
|
|
|
return false;
|
|
return false;
|