gemm_blocked.go

package matmul

import (
	"sync"

	"makarna/pkg/backend/cpu"
)
const (
	// minWorkPerWorker is a heuristic threshold on the total work units
	// (approximate multiply-accumulate count) below which fanning out to
	// multiple workers is not worth the goroutine overhead.
	minWorkPerWorker = 8192
)
// chooseWorkers returns a bounded worker count based on total work units and the
// maximum allowed.
func chooseWorkers(total, max int) int {
	if max < 1 {
		return 1
	}
	if total <= minWorkPerWorker {
		return 1
	}
	need := (total + minWorkPerWorker - 1) / minWorkPerWorker
	if need < 1 {
		need = 1
	}
	if need > max {
		need = max
	}
	return need
}
// chunkRanges splits [0,total) into at most parts chunks.
func chunkRanges(total, parts int) [][2]int {
	if parts < 1 {
		parts = 1
	}
	chunk := (total + parts - 1) / parts
	if chunk < 1 {
		chunk = 1
	}
	var ranges [][2]int
	for start := 0; start < total; start += chunk {
		end := start + chunk
		if end > total {
			end = total
		}
		ranges = append(ranges, [2]int{start, end})
	}
	return ranges
}
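// exampleSplitRows is an illustrative sketch, not part of the original file:
// it shows how chooseWorkers and chunkRanges combine to split the row-index
// space, mirroring the M-parallel path in gemmFloat32Blocked. For example,
// with M=64, K=512, N=512 and maxWorkers=8, the approximate MAC count is
// 64*512*512, chooseWorkers caps the answer at 8, and chunkRanges(64, 8)
// yields eight [start, end) ranges of 8 rows each. Note the last chunk may be
// shorter, e.g. chunkRanges(10, 3) yields [0,4) [4,8) [8,10). The helper name
// is hypothetical.
func exampleSplitRows(M, K, N, maxWorkers int) [][2]int {
	workers := chooseWorkers(M*N*K, maxWorkers) // approximate MACs, as in gemmFloat32Blocked
	return chunkRanges(M, workers)
}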
// gemmFloat32Blocked computes C = A x B (row-major) where
// A: MxK, B: NxK (row-major weights), C: MxN.
// It uses a register-blocked 1x8 micro-kernel (when available) and
// parallelizes across rows or columns depending on shape, without packing.
func gemmFloat32Blocked(out, a, b []float32, M, K, N, maxWorkers int) {
	// Use an approximate MAC count to decide parallelism. Using only M*N can
	// underutilize CPU cores in decode (M == 1), where K is large.
	total := M * N * K
	workers := chooseWorkers(total, maxWorkers)
	if workers == 1 {
		gemmFloat32Scalar(out, a, b, M, K, N)
		return
	}
	// Decode-path specialization: M == 1, split the single output row across N.
	if M == 1 {
		ranges := chunkRanges(N, workers)
		var wg sync.WaitGroup
		for _, r := range ranges {
			wg.Add(1)
			start, end := r[0], r[1]
			go func(s, e int) {
				defer wg.Done()
				cpu.WithPinnedThread(func() {
					gemvFloat32Range(out, a[:K], b, K, s, e)
				})
			}(start, end)
		}
		wg.Wait()
		return
	}
	// Few rows but many columns: keep every worker busy by splitting across N
	// and letting each worker walk all M rows within its column range.
	if M < workers && N > 1 {
		ranges := chunkRanges(N, workers)
		var wg sync.WaitGroup
		for _, r := range ranges {
			wg.Add(1)
			start, end := r[0], r[1]
			go func(s, e int) {
				defer wg.Done()
				cpu.WithPinnedThread(func() {
					for m := 0; m < M; m++ {
						row := a[m*K : (m+1)*K]
						base := out[m*N : (m+1)*N]
						gemvFloat32Range(base, row, b, K, s, e)
					}
				})
			}(start, end)
		}
		wg.Wait()
		return
	}
	// General case: enough rows to split across M; each worker handles a
	// contiguous block of output rows over the full column range.
	ranges := chunkRanges(M, workers)
	var wg sync.WaitGroup
	for _, r := range ranges {
		wg.Add(1)
		start, end := r[0], r[1]
		go func(s, e int) {
			defer wg.Done()
			cpu.WithPinnedThread(func() {
				for m := s; m < e; m++ {
					row := a[m*K : (m+1)*K]
					base := out[m*N : (m+1)*N]
					gemvFloat32Range(base, row, b, K, 0, N)
				}
			})
		}(start, end)
	}
	wg.Wait()
}
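// exampleGemmFloat32Blocked is a hypothetical usage sketch, not part of the
// original file: it multiplies a 2x3 activation block A by a 4x3 row-major
// weight matrix B (stored NxK, as the kernel expects), producing a 2x4 output.
// gemvFloat32Range and cpu.WithPinnedThread are assumed to be provided
// elsewhere in this package and backend; the weights here are left zeroed
// purely for illustration.
func exampleGemmFloat32Blocked() []float32 {
	const M, K, N = 2, 3, 4
	a := []float32{
		1, 2, 3, // row 0
		4, 5, 6, // row 1
	}
	b := make([]float32, N*K)                 // N*K row-major weights
	out := make([]float32, M*N)               // M*N row-major result
	gemmFloat32Blocked(out, a, b, M, K, N, 4) // allow up to 4 workers
	return out
}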
// gemmFloat32Scalar is the single-worker fallback: one GEMV per output row
// over the full column range.
func gemmFloat32Scalar(out, a, b []float32, M, K, N int) {
	for m := 0; m < M; m++ {
		row := a[m*K : (m+1)*K]
		base := out[m*N : (m+1)*N]
		gemvFloat32Range(base, row, b, K, 0, N)
	}
}