|
|
@@ -330,9 +330,11 @@ void main() {
|
|
|
// resize eM by using smear/reduce
|
|
|
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
|
|
|
|
|
- O = eMdiag * O;
|
|
|
+ // multiply with fp16 accumulation, then add to O.
|
|
|
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
|
|
|
+ PV = coopMatMulAdd(P_A, V, PV);
|
|
|
|
|
|
- O = coopMatMulAdd(P_A, V, O);
|
|
|
+ O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
|
|
|
}
|
|
|
|
|
|
// If there is split_k, then the split_k resolve shader does the final
|