package matmul

import (
	"sync"

	"makarna/pkg/backend/cpu"
)

const (
	// minWorkPerWorker is the minimum number of work units (approximate MACs)
	// assigned to a single worker before fanning out to more goroutines.
	minWorkPerWorker = 8192
)

// chooseWorkers returns a bounded worker count based on the total work units
// and the maximum allowed.
func chooseWorkers(total, max int) int {
	if max < 1 {
		return 1
	}
	if total <= minWorkPerWorker {
		return 1
	}
	need := (total + minWorkPerWorker - 1) / minWorkPerWorker
	if need < 1 {
		need = 1
	}
	if need > max {
		need = max
	}
	return need
}
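
// Illustrative results, assuming the minWorkPerWorker value above (8192):
//
//	chooseWorkers(4096, 8)  -> 1 (below the per-worker threshold, stay serial)
//	chooseWorkers(65536, 8) -> 8 (ceil(65536/8192) = 8)
//	chooseWorkers(1<<20, 8) -> 8 (need = 128, clamped to max)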

// chunkRanges splits [0,total) into at most parts chunks.
func chunkRanges(total, parts int) [][2]int {
	if parts < 1 {
		parts = 1
	}
	chunk := (total + parts - 1) / parts
	if chunk < 1 {
		chunk = 1
	}
	var ranges [][2]int
	for start := 0; start < total; start += chunk {
		end := start + chunk
		if end > total {
			end = total
		}
		ranges = append(ranges, [2]int{start, end})
	}
	return ranges
}
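
// Illustrative chunkings, following the ceil-division chunk size above:
//
//	chunkRanges(10, 3) -> [0,4) [4,8) [8,10)
//	chunkRanges(5, 4)  -> [0,2) [2,4) [4,5)  (may yield fewer than parts chunks)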

// gemmFloat32Blocked computes C = A x B (row-major) where
// A: MxK, B: NxK (row-major weights), C: MxN.
// It uses a register-blocked 1x8 micro-kernel (when available) and
// parallelizes across rows or columns depending on shape, without packing.
func gemmFloat32Blocked(out, a, b []float32, M, K, N, maxWorkers int) {
	// Use an approximate MAC count (M*N*K) to decide parallelism. Using only
	// M*N can underutilize CPU cores in decode (M == 1), where K is large.
	total := M * N * K
	workers := chooseWorkers(total, maxWorkers)
	if workers == 1 {
		gemmFloat32Scalar(out, a, b, M, K, N)
		return
	}

	// Decode-path specialization: M == 1, split the output columns across workers.
	if M == 1 {
		ranges := chunkRanges(N, workers)
		var wg sync.WaitGroup
		for _, r := range ranges {
			wg.Add(1)
			start, end := r[0], r[1]
			go func(s, e int) {
				defer wg.Done()
				cpu.WithPinnedThread(func() {
					gemvFloat32Range(out, a[:K], b, K, s, e)
				})
			}(start, end)
		}
		wg.Wait()
		return
	}

	// Few rows relative to the worker count: splitting M would leave workers
	// idle, so split the output columns and let each worker walk all rows.
	if M < workers && N > 1 {
		ranges := chunkRanges(N, workers)
		var wg sync.WaitGroup
		for _, r := range ranges {
			wg.Add(1)
			start, end := r[0], r[1]
			go func(s, e int) {
				defer wg.Done()
				cpu.WithPinnedThread(func() {
					for m := 0; m < M; m++ {
						row := a[m*K : (m+1)*K]
						base := out[m*N : (m+1)*N]
						gemvFloat32Range(base, row, b, K, s, e)
					}
				})
			}(start, end)
		}
		wg.Wait()
		return
	}

	// General case: split the rows of A across workers; each worker computes
	// full output rows.
	ranges := chunkRanges(M, workers)
	var wg sync.WaitGroup
	for _, r := range ranges {
		wg.Add(1)
		start, end := r[0], r[1]
		go func(s, e int) {
			defer wg.Done()
			cpu.WithPinnedThread(func() {
				for m := s; m < e; m++ {
					row := a[m*K : (m+1)*K]
					base := out[m*N : (m+1)*N]
					gemvFloat32Range(base, row, b, K, 0, N)
				}
			})
		}(start, end)
	}
	wg.Wait()
}
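
// Branch selection, for illustration only (values assume maxWorkers = 8 and
// the minWorkPerWorker threshold above):
//
//	M=1,   K=4096, N=4096 -> decode path, output columns split across workers
//	M=4,   K=4096, N=4096 -> M < workers, output columns split across workers
//	M=512, K=4096, N=4096 -> rows of A split across workers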

// gemmFloat32Scalar is the single-threaded fallback: one GEMV per row of A.
func gemmFloat32Scalar(out, a, b []float32, M, K, N int) {
	for m := 0; m < M; m++ {
		row := a[m*K : (m+1)*K]
		base := out[m*N : (m+1)*N]
		gemvFloat32Range(base, row, b, K, 0, N)
	}
}
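
// Usage sketch (illustrative; using runtime.NumCPU() as the worker cap is an
// assumption, real callers may bound workers differently):
//
//	a := make([]float32, M*K)   // row-major activations, MxK
//	b := make([]float32, N*K)   // row-major weights, NxK
//	out := make([]float32, M*N) // row-major output, MxN
//	gemmFloat32Blocked(out, a, b, M, K, N, runtime.NumCPU())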