package matmul

import (
	"sync"

	"makarna/pkg/backend/cpu"
)

const (
	// minWorkPerWorker is a heuristic threshold on the approximate work
	// size (MAC count) below which spawning goroutines costs more than
	// it saves.
	minWorkPerWorker = 8192
)

// chooseWorkers returns a bounded worker count based on total work units and
// the maximum allowed. It always returns at least 1, returns 1 whenever the
// total work fits under the per-worker heuristic, and never exceeds max.
func chooseWorkers(total, max int) int {
	if max < 1 {
		return 1
	}
	if total <= minWorkPerWorker {
		return 1
	}
	// Ceiling division: how many workers the work "deserves".
	need := (total + minWorkPerWorker - 1) / minWorkPerWorker
	if need < 1 {
		need = 1
	}
	if need > max {
		need = max
	}
	return need
}

// chunkRanges splits [0, total) into at most parts contiguous half-open
// ranges of near-equal size. It returns nil when total <= 0.
func chunkRanges(total, parts int) [][2]int {
	if parts < 1 {
		parts = 1
	}
	// Ceiling division so the last chunk is never larger than the others.
	chunk := (total + parts - 1) / parts
	if chunk < 1 {
		chunk = 1
	}
	var ranges [][2]int
	for start := 0; start < total; start += chunk {
		end := start + chunk
		if end > total {
			end = total
		}
		ranges = append(ranges, [2]int{start, end})
	}
	return ranges
}

// runParallel spawns one pinned-thread goroutine per range and blocks until
// all of them finish. fn receives the half-open range [s, e). The range
// bounds are passed as explicit arguments so each goroutine sees its own
// values regardless of Go version loop-variable semantics.
func runParallel(ranges [][2]int, fn func(s, e int)) {
	var wg sync.WaitGroup
	for _, r := range ranges {
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			cpu.WithPinnedThread(func() { fn(s, e) })
		}(r[0], r[1])
	}
	wg.Wait()
}

// gemmFloat32Blocked computes C = A x B (row-major) where
// A: MxK, B: NxK (row-major weights), C: MxN.
// It uses a register-blocked 1x8 micro-kernel (when available) and
// parallelizes across rows or columns depending on shape, without packing.
// maxWorkers bounds the number of goroutines; values < 1 force serial
// execution.
func gemmFloat32Blocked(out, a, b []float32, M, K, N, maxWorkers int) {
	// Use an approximate MAC count to decide parallelism. Using only M*N can
	// underutilize CPU cores in decode (M==1) where K is large.
	total := M * N * K
	workers := chooseWorkers(total, maxWorkers)
	if workers == 1 {
		gemmFloat32Scalar(out, a, b, M, K, N)
		return
	}

	// Decode-path specialization: M == 1, split the single output row
	// across N columns.
	if M == 1 {
		runParallel(chunkRanges(N, workers), func(s, e int) {
			gemvFloat32Range(out, a[:K], b, K, s, e)
		})
		return
	}

	// Few rows, many columns: splitting across M alone cannot occupy all
	// workers, so split across N and let each worker walk every row.
	if M < workers && N > 1 {
		runParallel(chunkRanges(N, workers), func(s, e int) {
			for m := 0; m < M; m++ {
				row := a[m*K : (m+1)*K]
				base := out[m*N : (m+1)*N]
				gemvFloat32Range(base, row, b, K, s, e)
			}
		})
		return
	}

	// General case: split across rows; each worker computes its rows fully.
	runParallel(chunkRanges(M, workers), func(s, e int) {
		for m := s; m < e; m++ {
			row := a[m*K : (m+1)*K]
			base := out[m*N : (m+1)*N]
			gemvFloat32Range(base, row, b, K, 0, N)
		}
	})
}

// gemmFloat32Scalar is the single-threaded fallback: it computes every output
// row in sequence via the vector kernel.
func gemmFloat32Scalar(out, a, b []float32, M, K, N int) {
	for m := 0; m < M; m++ {
		row := a[m*K : (m+1)*K]
		base := out[m*N : (m+1)*N]
		gemvFloat32Range(base, row, b, K, 0, N)
	}
}