| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- //go:build amd64
- package matmul
- import "makarna/pkg/backend/cpu"
// f32NR is the tile height used by the tiled loops in this file: the number of
// consecutive B rows (and therefore consecutive out elements) handed to a
// single micro-kernel call.
- const f32NR = 8
- // gemvFloat32Range computes out[startN:endN] = aRow * B where:
- // - aRow is length K
- // - B is NxK row-major (weights)
- // - out is length at least endN
- //
- // It prefers a register-blocked 1x8 micro-kernel on AVX2/AVX-512.
- func gemvFloat32Range(out, aRow, b []float32, K, startN, endN int) {
- if startN >= endN {
- return
- }
- if cpu.SupportsAVX512() && K >= 16 {
- gemvFloat32RangeAVX512(out, aRow, b, K, startN, endN)
- return
- }
- if cpu.SupportsAVX2() && K >= 8 {
- gemvFloat32RangeAVX2(out, aRow, b, K, startN, endN)
- return
- }
- bOff := startN * K
- for n := startN; n < endN; n++ {
- out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K])
- bOff += K
- }
- }
- func gemvFloat32RangeAVX2(out, aRow, b []float32, K, startN, endN int) {
- kMain := K &^ 7
- if kMain <= 0 {
- bOff := startN * K
- for n := startN; n < endN; n++ {
- out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K])
- bOff += K
- }
- return
- }
- aTail := aRow[kMain:K]
- kTail := K - kMain
- bOff := startN * K
- n := startN
- for ; n+f32NR <= endN; n += f32NR {
- gemvF32Tile8AVX2(&aRow[0], &b[bOff], &out[n], K)
- if kTail != 0 {
- for t := 0; t < f32NR; t++ {
- rowOff := bOff + t*K + kMain
- var tail float32
- for i := 0; i < kTail; i++ {
- tail += aTail[i] * b[rowOff+i]
- }
- out[n+t] += tail
- }
- }
- bOff += f32NR * K
- }
- for ; n < endN; n++ {
- out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K])
- bOff += K
- }
- }
- func gemvFloat32RangeAVX512(out, aRow, b []float32, K, startN, endN int) {
- kMain := K &^ 15
- if kMain <= 0 {
- gemvFloat32RangeAVX2(out, aRow, b, K, startN, endN)
- return
- }
- aTail := aRow[kMain:K]
- kTail := K - kMain
- bOff := startN * K
- n := startN
- for ; n+f32NR <= endN; n += f32NR {
- gemvF32Tile8AVX512(&aRow[0], &b[bOff], &out[n], K)
- if kTail != 0 {
- for t := 0; t < f32NR; t++ {
- rowOff := bOff + t*K + kMain
- var tail float32
- for i := 0; i < kTail; i++ {
- tail += aTail[i] * b[rowOff+i]
- }
- out[n+t] += tail
- }
- }
- bOff += f32NR * K
- }
- for ; n < endN; n++ {
- out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K])
- bOff += K
- }
- }
// gemvF32Tile8AVX2 is declared without a body; its implementation is not in
// this file (presumably an amd64 assembly source — confirm).
//
// The caller in this file passes a pointer to the first element of the input
// row (a), a pointer to the first B element of the current tile (b), a pointer
// to the tile's first output (out), and the row length K.
//
// NOTE(review): after the call, the caller adds the products of the last K%8
// columns to out[0..7], so the kernel presumably stores 8 partial dot products
// covering the first K&^7 columns of 8 consecutive rows — verify against the
// assembly. No bounds are enforced on any of the three pointers.
- //go:noescape
- func gemvF32Tile8AVX2(a *float32, b *float32, out *float32, K int)
// gemvF32Tile8AVX512 is declared without a body; its implementation is not in
// this file (presumably an amd64 assembly source — confirm).
//
// The caller in this file passes a pointer to the first element of the input
// row (a), a pointer to the first B element of the current tile (b), a pointer
// to the tile's first output (out), and the row length K.
//
// NOTE(review): after the call, the caller adds the products of the last K%16
// columns to out[0..7], so the kernel presumably stores 8 partial dot products
// covering the first K&^15 columns of 8 consecutive rows — verify against the
// assembly. No bounds are enforced on any of the three pointers.
- //go:noescape
- func gemvF32Tile8AVX512(a *float32, b *float32, out *float32, K int)
|