|
|
@@ -249,13 +249,16 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
|
if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
|
|
|
const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
|
|
|
+ GGML_SYCL_DEBUG("%s: F16 mask\n", __func__);
|
|
|
soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
|
|
|
main_stream, ctx.device);
|
|
|
} else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
|
|
|
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
|
|
|
+ GGML_SYCL_DEBUG("%s: F32 mask\n", __func__);
|
|
|
soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
|
|
|
} else {
|
|
|
/* mask unavailable */
|
|
|
+ GGML_SYCL_DEBUG("%s: No mask\n", __func__);
|
|
|
soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
|
|
|
}
|
|
|
}
|