|
@@ -277,7 +277,7 @@ static void soft_max_f32_sycl(const float *x, const T *mask,
|
|
|
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(
|
|
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(
|
|
|
x, mask, sinks, dst, params, stream, block_dims, block_nums,
|
|
x, mask, sinks, dst, params, stream, block_dims, block_nums,
|