@@ -245,6 +245,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
+    vk_pipeline pipeline_rwkv_wkv6_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -528,6 +529,13 @@ struct vk_op_pool2d_push_constants {
     int32_t p0; int32_t p1;
 };
 
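+// Push constants for the rwkv_wkv6 shader, describing the problem size:
+// B = batch (n_seqs), T = sequence length, C = channels (n_embed), H = number of heads.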
+struct vk_op_rwkv_wkv6_push_constants {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t H;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -2014,6 +2022,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
 
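+    // rwkv_wkv6: 7 storage-buffer bindings, vk_op_rwkv_wkv6_push_constants as push constants,
+    // and the device subgroup size passed as a specialization constant.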
+    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+
     for (auto &c : compiles) {
         c.wait();
     }
@@ -5022,6 +5032,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_pool2d_f32;
         }
         return nullptr;
+    case GGML_OP_RWKV_WKV6:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rwkv_wkv6_f32;
+        }
+        return nullptr;
     case GGML_OP_LEAKY_RELU:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_leaky_relu_f32;
@@ -5424,6 +5439,134 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
+    const ggml_tensor * k = dst->src[0];
+    const ggml_tensor * v = dst->src[1];
+    const ggml_tensor * r = dst->src[2];
+    const ggml_tensor * tf = dst->src[3];
+    const ggml_tensor * td = dst->src[4];
+    const ggml_tensor * state = dst->src[5];
+
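+    // The WKV6 kernel only handles float tensors; quantized inputs are not supported.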
+    GGML_ASSERT(!ggml_is_quantized(k->type));
+    GGML_ASSERT(!ggml_is_quantized(v->type));
+    GGML_ASSERT(!ggml_is_quantized(r->type));
+    GGML_ASSERT(!ggml_is_quantized(tf->type));
+    GGML_ASSERT(!ggml_is_quantized(td->type));
+    GGML_ASSERT(!ggml_is_quantized(state->type));
+    GGML_ASSERT(dst->buffer != nullptr);
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
+    GGML_ASSERT(pipeline != nullptr);
+
+    if (dryrun) {
+        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        return;
+    }
+
+    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
+    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
+    ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
+    ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
+    ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
+    ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
+
+    ggml_vk_sync_buffers(subctx);
+
+    vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
+    uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
+    bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
+
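+    // On unified-memory (UMA) devices, try to resolve the tensor data to host-visible buffers first.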
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
+        ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
+        ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
+        ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
+        ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
+        ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
+        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+
+        K_uma = d_K != nullptr;
+        V_uma = d_V != nullptr;
+        R_uma = d_R != nullptr;
+        TF_uma = d_TF != nullptr;
+        TD_uma = d_TD != nullptr;
+        STATE_uma = d_State != nullptr;
+        DST_uma = d_D != nullptr;
+    }
+
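+    // Fall back to the device buffers for any tensor that is not host-visible.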
+    if (!K_uma) {
+        d_K = k_buf_ctx->dev_buffer;
+        k_offset = vk_tensor_offset(k) + k->view_offs;
+    }
+    if (!V_uma) {
+        d_V = v_buf_ctx->dev_buffer;
+        v_offset = vk_tensor_offset(v) + v->view_offs;
+    }
+    if (!R_uma) {
+        d_R = r_buf_ctx->dev_buffer;
+        r_offset = vk_tensor_offset(r) + r->view_offs;
+    }
+    if (!TF_uma) {
+        d_TF = tf_buf_ctx->dev_buffer;
+        tf_offset = vk_tensor_offset(tf) + tf->view_offs;
+    }
+    if (!TD_uma) {
+        d_TD = td_buf_ctx->dev_buffer;
+        td_offset = vk_tensor_offset(td) + td->view_offs;
+    }
+    if (!STATE_uma) {
+        d_State = state_buf_ctx->dev_buffer;
+        state_offset = vk_tensor_offset(state) + state->view_offs;
+    }
+    if (!DST_uma) {
+        d_D = dst_buf_ctx->dev_buffer;
+        dst_offset = vk_tensor_offset(dst) + dst->view_offs;
+    }
+
+    const uint64_t k_size = ggml_nbytes(k);
+    const uint64_t v_size = ggml_nbytes(v);
+    const uint64_t r_size = ggml_nbytes(r);
+    const uint64_t tf_size = ggml_nbytes(tf);
+    const uint64_t td_size = ggml_nbytes(td);
+    const uint64_t state_size = ggml_nbytes(state);
+    const uint64_t dst_size = ggml_nbytes(dst);
+
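+    // Grid size: B * H elements in x, i.e. one per (sequence, head) pair.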
+    std::array<uint32_t, 3> elements = {
+        (uint32_t)(pc.B * pc.H),
+        1,
+        1
+    };
+
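+    // The subbuffer order corresponds to the shader's descriptor bindings: k, v, r, tf, td, state, dst.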
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+        vk_subbuffer{ d_K, k_offset, k_size },
+        vk_subbuffer{ d_V, v_offset, v_size },
+        vk_subbuffer{ d_R, r_offset, r_size },
+        vk_subbuffer{ d_TF, tf_offset, tf_size },
+        vk_subbuffer{ d_TD, td_offset, td_size },
+        vk_subbuffer{ d_State, state_offset, state_size },
+        vk_subbuffer{ d_D, dst_offset, dst_size }
+    }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+}
+
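+// Translate the ggml tensor shapes of the RWKV_WKV6 op into the shader's push constants (B, T, C, H).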
+static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+    const size_t seq_length = dst->src[0]->ne[3];
+    const size_t n_embed = dst->ne[0];
+    const size_t n_heads = dst->src[0]->ne[2];
+    const size_t n_seqs = dst->src[5]->ne[1];
+
+    ggml_vk_op_f32_rwkv6(
+        ctx, subctx, dst,
+        {
+            (uint32_t)n_seqs,
+            (uint32_t)seq_length,
+            (uint32_t)n_embed,
+            (uint32_t)n_heads,
+        },
+        dryrun
+    );
+}
+
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     int * op_params = (int *)dst->op_params;
 
@@ -6569,6 +6712,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
         break;
@@ -6768,6 +6912,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_FLASH_ATTN_EXT:
         ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
 
+        break;
+
+    case GGML_OP_RWKV_WKV6:
+        ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
+
         break;
     default:
         return false;
@@ -6848,6 +6997,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_REPEAT:
         buf = tensor->buffer;
@@ -7724,6 +7874,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
    case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
         return true;
     default:
@@ -8300,7 +8451,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
         tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
-    } else {
+    } else if (tensor->op == GGML_OP_RWKV_WKV6) {
+        tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
+            tensor->src[4], tensor->src[5]);
+    }
+    else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
         GGML_ABORT("fatal error");
     }