gemv_f32_tiled_amd64.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. //go:build amd64
  2. package matmul
  3. import "makarna/pkg/backend/cpu"
  4. const f32NR = 8
  5. // gemvFloat32Range computes out[startN:endN] = aRow * B where:
  6. // - aRow is length K
  7. // - B is NxK row-major (weights)
  8. // - out is length at least endN
  9. //
  10. // It prefers a register-blocked 1x8 micro-kernel on AVX2/AVX-512.
  11. func gemvFloat32Range(out, aRow, b []float32, K, startN, endN int) {
  12. if startN >= endN {
  13. return
  14. }
  15. if cpu.SupportsAVX512() && K >= 16 {
  16. gemvFloat32RangeAVX512(out, aRow, b, K, startN, endN)
  17. return
  18. }
  19. if cpu.SupportsAVX2() && K >= 8 {
  20. gemvFloat32RangeAVX2(out, aRow, b, K, startN, endN)
  21. return
  22. }
  23. bOff := startN * K
  24. for n := startN; n < endN; n++ {
  25. out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K])
  26. bOff += K
  27. }
  28. }
  29. func gemvFloat32RangeAVX2(out, aRow, b []float32, K, startN, endN int) {
  30. kMain := K &^ 7
  31. if kMain <= 0 {
  32. bOff := startN * K
  33. for n := startN; n < endN; n++ {
  34. out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K])
  35. bOff += K
  36. }
  37. return
  38. }
  39. aTail := aRow[kMain:K]
  40. kTail := K - kMain
  41. bOff := startN * K
  42. n := startN
  43. for ; n+f32NR <= endN; n += f32NR {
  44. gemvF32Tile8AVX2(&aRow[0], &b[bOff], &out[n], K)
  45. if kTail != 0 {
  46. for t := 0; t < f32NR; t++ {
  47. rowOff := bOff + t*K + kMain
  48. var tail float32
  49. for i := 0; i < kTail; i++ {
  50. tail += aTail[i] * b[rowOff+i]
  51. }
  52. out[n+t] += tail
  53. }
  54. }
  55. bOff += f32NR * K
  56. }
  57. for ; n < endN; n++ {
  58. out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K])
  59. bOff += K
  60. }
  61. }
  62. func gemvFloat32RangeAVX512(out, aRow, b []float32, K, startN, endN int) {
  63. kMain := K &^ 15
  64. if kMain <= 0 {
  65. gemvFloat32RangeAVX2(out, aRow, b, K, startN, endN)
  66. return
  67. }
  68. aTail := aRow[kMain:K]
  69. kTail := K - kMain
  70. bOff := startN * K
  71. n := startN
  72. for ; n+f32NR <= endN; n += f32NR {
  73. gemvF32Tile8AVX512(&aRow[0], &b[bOff], &out[n], K)
  74. if kTail != 0 {
  75. for t := 0; t < f32NR; t++ {
  76. rowOff := bOff + t*K + kMain
  77. var tail float32
  78. for i := 0; i < kTail; i++ {
  79. tail += aTail[i] * b[rowOff+i]
  80. }
  81. out[n+t] += tail
  82. }
  83. }
  84. bOff += f32NR * K
  85. }
  86. for ; n < endN; n++ {
  87. out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K])
  88. bOff += K
  89. }
  90. }
// gemvF32Tile8AVX2 is implemented in assembly. Its Go caller passes a pointer
// to aRow[0], a pointer to the first of f32NR consecutive K-element rows of
// the row-major matrix b, and a pointer to out[n]. The kernel is expected to
// produce out[n], ..., out[n+f32NR-1].
//
// NOTE(review): the caller never initialises out before the call and adds the
// K%8 tail columns afterwards with +=. That implies the kernel stores (rather
// than accumulates) dot products over only the first K&^7 columns, with K
// doubling as the row stride of b. Confirm against the .s file.
//
//go:noescape
func gemvF32Tile8AVX2(a *float32, b *float32, out *float32, K int)

// gemvF32Tile8AVX512 is the AVX-512 counterpart of gemvF32Tile8AVX2, called
// with the same arguments.
//
// NOTE(review): its caller adds the K%16 tail columns afterwards with +=,
// which implies this kernel covers only the first K&^15 columns. Confirm
// against the .s file.
//
//go:noescape
func gemvF32Tile8AVX512(a *float32, b *float32, out *float32, K int)