|
|
@@ -1261,6 +1261,7 @@ struct vk_op_im2col_push_constants {
|
|
|
int32_t s0; int32_t s1;
|
|
|
int32_t p0; int32_t p1;
|
|
|
int32_t d0; int32_t d1;
|
|
|
+ uint32_t batch_IC;
|
|
|
};
|
|
|
|
|
|
struct vk_op_im2col_3d_push_constants {
|
|
|
@@ -5902,6 +5903,9 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
|
|
|
std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
|
|
|
}
|
|
|
std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
|
|
|
+ GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
|
|
|
+ wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
|
|
|
+ wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
|
|
GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
|
|
|
GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
|
|
|
GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
|
|
|
@@ -9090,6 +9094,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
|
|
|
|
|
|
elements = { OW * KW * KH, OH, batch * IC };
|
|
|
+ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
|
|
+ elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
|
|
} break;
|
|
|
case GGML_OP_IM2COL_3D:
|
|
|
{
|
|
|
@@ -10605,6 +10611,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
|
|
|
|
|
|
const uint32_t pelements = OW * KW * KH;
|
|
|
+ const uint32_t batch = src1->ne[is_2D ? 3 : 2];
|
|
|
|
|
|
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
|
|
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
|
|
|
@@ -10617,7 +10624,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
IC, IW, IH, OW, OH, KW, KH,
|
|
|
pelements,
|
|
|
IC * KH * KW,
|
|
|
- s0, s1, p0, p1, d0, d1,
|
|
|
+ s0, s1, p0, p1, d0, d1, batch * IC
|
|
|
});
|
|
|
}
|
|
|
|