//go:build amd64 package matmul import "makarna/pkg/backend/cpu" const f32NR = 8 // gemvFloat32Range computes out[startN:endN] = aRow * B where: // - aRow is length K // - B is NxK row-major (weights) // - out is length at least endN // // It prefers a register-blocked 1x8 micro-kernel on AVX2/AVX-512. func gemvFloat32Range(out, aRow, b []float32, K, startN, endN int) { if startN >= endN { return } if cpu.SupportsAVX512() && K >= 16 { gemvFloat32RangeAVX512(out, aRow, b, K, startN, endN) return } if cpu.SupportsAVX2() && K >= 8 { gemvFloat32RangeAVX2(out, aRow, b, K, startN, endN) return } bOff := startN * K for n := startN; n < endN; n++ { out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K]) bOff += K } } func gemvFloat32RangeAVX2(out, aRow, b []float32, K, startN, endN int) { kMain := K &^ 7 if kMain <= 0 { bOff := startN * K for n := startN; n < endN; n++ { out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K]) bOff += K } return } aTail := aRow[kMain:K] kTail := K - kMain bOff := startN * K n := startN for ; n+f32NR <= endN; n += f32NR { gemvF32Tile8AVX2(&aRow[0], &b[bOff], &out[n], K) if kTail != 0 { for t := 0; t < f32NR; t++ { rowOff := bOff + t*K + kMain var tail float32 for i := 0; i < kTail; i++ { tail += aTail[i] * b[rowOff+i] } out[n+t] += tail } } bOff += f32NR * K } for ; n < endN; n++ { out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K]) bOff += K } } func gemvFloat32RangeAVX512(out, aRow, b []float32, K, startN, endN int) { kMain := K &^ 15 if kMain <= 0 { gemvFloat32RangeAVX2(out, aRow, b, K, startN, endN) return } aTail := aRow[kMain:K] kTail := K - kMain bOff := startN * K n := startN for ; n+f32NR <= endN; n += f32NR { gemvF32Tile8AVX512(&aRow[0], &b[bOff], &out[n], K) if kTail != 0 { for t := 0; t < f32NR; t++ { rowOff := bOff + t*K + kMain var tail float32 for i := 0; i < kTail; i++ { tail += aTail[i] * b[rowOff+i] } out[n+t] += tail } } bOff += f32NR * K } for ; n < endN; n++ { out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K]) bOff += K } } //go:noescape func gemvF32Tile8AVX2(a *float32, b *float32, out *float32, K int) //go:noescape func gemvF32Tile8AVX512(a *float32, b *float32, out *float32, K int)