|
|
@@ -386,10 +386,13 @@ struct vk_flash_attn_push_constants {
|
|
|
uint32_t nev3;
|
|
|
uint32_t nem1;
|
|
|
|
|
|
+ uint32_t nb01;
|
|
|
uint32_t nb02;
|
|
|
uint32_t nb03;
|
|
|
+ uint32_t nb11;
|
|
|
uint32_t nb12;
|
|
|
uint32_t nb13;
|
|
|
+ uint32_t nb21;
|
|
|
uint32_t nb22;
|
|
|
uint32_t nb23;
|
|
|
uint32_t nb31;
|
|
|
@@ -4809,7 +4812,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
}
|
|
|
assert(pipelines);
|
|
|
|
|
|
- bool aligned = (KV % pipelines[1]->align) == 0;
|
|
|
+ const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
|
|
|
+ const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
|
|
+ const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
|
|
|
+
|
|
|
+ bool aligned = (KV % pipelines[1]->align) == 0 &&
|
|
|
+ // the "aligned" shader variant will forcibly align strides, for performance
|
|
|
+ (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
|
|
|
+
|
|
|
vk_pipeline pipeline = pipelines[aligned];
|
|
|
assert(pipeline);
|
|
|
|
|
|
@@ -4845,15 +4855,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
|
|
|
if (ctx->device->uma) {
|
|
|
ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
|
|
|
- ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset);
|
|
|
- ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset);
|
|
|
- ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset);
|
|
|
+ ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset);
|
|
|
+ ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset);
|
|
|
+ ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
|
|
|
Q_uma = d_Q != nullptr;
|
|
|
K_uma = d_K != nullptr;
|
|
|
V_uma = d_V != nullptr;
|
|
|
D_uma = d_D != nullptr;
|
|
|
if (mask) {
|
|
|
- ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset);
|
|
|
+ ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
|
|
|
M_uma = d_M != nullptr;
|
|
|
}
|
|
|
}
|
|
|
@@ -4891,7 +4901,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 };
|
|
|
+ const vk_flash_attn_push_constants pc = { N, KV,
|
|
|
+ (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
|
|
|
+ (uint32_t)neq2, (uint32_t)neq3,
|
|
|
+ (uint32_t)nek2, (uint32_t)nek3,
|
|
|
+ (uint32_t)nev2, (uint32_t)nev3,
|
|
|
+ nem1,
|
|
|
+ q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
|
|
|
+ k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
|
|
|
+ v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
|
|
|
+ nbm1,
|
|
|
+ scale, max_bias, logit_softcap,
|
|
|
+ mask != nullptr, n_head_log2, m0, m1 };
|
|
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
|
|
{
|
|
|
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
|
|
|
@@ -8668,6 +8689,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
|
|
|
ggml_tensor * src0 = tensor->src[0];
|
|
|
ggml_tensor * src1 = tensor->src[1];
|
|
|
ggml_tensor * src2 = tensor->src[2];
|
|
|
+ ggml_tensor * src3 = tensor->src[3];
|
|
|
|
|
|
void * tensor_data = tensor->data;
|
|
|
|
|
|
@@ -8730,6 +8752,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
|
|
|
if (src2 != nullptr) {
|
|
|
std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
|
|
|
}
|
|
|
+ if (src3 != nullptr) {
|
|
|
+ std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
|
|
|
+ }
|
|
|
std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
|
|
|
std::cerr << std::endl << "Result:" << std::endl;
|
|
|
ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
|
|
|
@@ -8774,6 +8799,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
|
|
|
if (src2 != nullptr) {
|
|
|
std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
|
|
|
}
|
|
|
+ if (src3 != nullptr) {
|
|
|
+ std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
|
|
|
+ }
|
|
|
std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
|
|
|
std::cerr << std::endl << "Result:" << std::endl;
|
|
|
ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
|
|
|
@@ -8796,6 +8824,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
|
|
|
if (src2 != nullptr) {
|
|
|
std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
|
|
|
}
|
|
|
+ if (src3 != nullptr) {
|
|
|
+ std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
|
|
|
+ }
|
|
|
std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
|
|
|
std::cerr << std::endl << "Result:" << std::endl;
|
|
|
ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);
|