@@ -256,6 +256,9 @@ void main() {
barrier();
}
+ // prevent race on tmpsh
+ barrier();
+
// reduce across threads
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
@@ -302,6 +302,9 @@ void main() {
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];