package matmul

import (
	"fmt"
	"sync"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/tensor"
)

// linearCPU contains the original CPU implementations for all supported
// weight dtypes. Both CPU-only and CUDA-enabled builds reuse this.
func linearCPU(input, weight, output *cpu.Tensor) error {
	inShape := input.Shape()
	wShape := weight.Shape()

	// Validate dimensions.
	if len(inShape) != 2 || len(wShape) != 2 {
		return fmt.Errorf("linear: expected 2D inputs, got input %v, weight %v", inShape, wShape)
	}

	M := inShape[0]
	K := inShape[1]
	N := wShape[0]
	if wShape[1] != K {
		return fmt.Errorf("linear: shape mismatch: input [*, %d] vs weight [%d, %d]", K, N, wShape[1])
	}

	inData := input.DataFloat32()
	outData := output.DataFloat32()
	workers := cpu.MaxThreads()

	// Dispatch on the weight dtype; activations are always float32.
	switch weight.DType() {
	case tensor.Float32:
		wData := weight.DataFloat32()
		gemmFloat32Blocked(outData, inData, wData, M, K, N, workers)

	case tensor.Q4_K:
		wData := weight.DataQ4_K()
		if K%256 != 0 {
			return fmt.Errorf("linear: Q4_K weight K dimension %d must be multiple of 256", K)
		}
		wParams := tensor.GetQ4KDotParams(wData)
		blocksPerRow := K / 256
		work := M * N * K
		// chooseWorkers (defined elsewhere in the package) is assumed to
		// scale the worker count down for small problems; use == 1 takes
		// the serial path.
		use := chooseWorkers(work, workers)
		if use == 1 {
			if M == 1 {
				q4kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, 0, N)
				return nil
			}
			for m := 0; m < M; m++ {
				for n := 0; n < N; n++ {
					var sum float32
					for b := 0; b < blocksPerRow; b++ {
						inOffset := m*K + b*256
						wBlockIdx := n*blocksPerRow + b
						block := &wData[wBlockIdx]
						p := &wParams[wBlockIdx]
						sum += tensor.DotQ4_K_Params(block, p, inData[inOffset:inOffset+256])
					}
					outData[m*N+n] = sum
				}
			}
			return nil
		}
		var wg sync.WaitGroup
		if M == 1 {
			// Single-row input: parallelize the GEMV across output columns.
			for _, r := range chunkRanges(N, use) {
				wg.Add(1)
				start, end := r[0], r[1]
				go func(s, e int) {
					defer wg.Done()
					q4kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, s, e)
				}(start, end)
			}
			wg.Wait()
			return nil
		}
		if M < use {
			// Too few rows to keep every worker busy: split over N instead.
			for _, r := range chunkRanges(N, use) {
				wg.Add(1)
				start, end := r[0], r[1]
				go func(s, e int) {
					defer wg.Done()
					for n := s; n < e; n++ {
						for m := 0; m < M; m++ {
							var sum float32
							for b := 0; b < blocksPerRow; b++ {
								inOffset := m*K + b*256
								wBlockIdx := n*blocksPerRow + b
								block := &wData[wBlockIdx]
								p := &wParams[wBlockIdx]
								sum += tensor.DotQ4_K_Params(block, p, inData[inOffset:inOffset+256])
							}
							outData[m*N+n] = sum
						}
					}
				}(start, end)
			}
			wg.Wait()
			return nil
		}
		// Enough rows: split over M.
		for _, r := range chunkRanges(M, use) {
			wg.Add(1)
			start, end := r[0], r[1]
			go func(s, e int) {
				defer wg.Done()
				for m := s; m < e; m++ {
					for n := 0; n < N; n++ {
						var sum float32
						for b := 0; b < blocksPerRow; b++ {
							inOffset := m*K + b*256
							wBlockIdx := n*blocksPerRow + b
							block := &wData[wBlockIdx]
							p := &wParams[wBlockIdx]
							sum += tensor.DotQ4_K_Params(block, p, inData[inOffset:inOffset+256])
						}
						outData[m*N+n] = sum
					}
				}
			}(start, end)
		}
		wg.Wait()
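	// The remaining quantized cases repeat the Q4_K dispatch: a tiled GEMV
	// over N when M == 1, a parallel split over N when M < use, and a split
	// over M otherwise. Q8_K is the one format without precomputed dot
	// parameters, so DotQ8_K consumes the raw block directly.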
	case tensor.Q8_K:
		wData := weight.DataQ8_K()
		if K%256 != 0 {
			return fmt.Errorf("linear: Q8_K weight K dimension %d must be multiple of 256", K)
		}
		blocksPerRow := K / 256
		work := M * N * K
		use := chooseWorkers(work, workers)
		if use == 1 {
			if M == 1 {
				q8kGemvDecodeTiled(outData[:N], inData[:K], wData, N, blocksPerRow, 0, N)
				return nil
			}
			for m := 0; m < M; m++ {
				for n := 0; n < N; n++ {
					var sum float32
					for b := 0; b < blocksPerRow; b++ {
						inOffset := m*K + b*256
						wBlockIdx := n*blocksPerRow + b
						block := &wData[wBlockIdx]
						sum += tensor.DotQ8_K(block, inData[inOffset:inOffset+256])
					}
					outData[m*N+n] = sum
				}
			}
			return nil
		}
		var wg sync.WaitGroup
		if M == 1 {
			for _, r := range chunkRanges(N, use) {
				wg.Add(1)
				start, end := r[0], r[1]
				go func(s, e int) {
					defer wg.Done()
					q8kGemvDecodeTiled(outData[:N], inData[:K], wData, N, blocksPerRow, s, e)
				}(start, end)
			}
			wg.Wait()
			return nil
		}
		if M < use {
			for _, r := range chunkRanges(N, use) {
				wg.Add(1)
				start, end := r[0], r[1]
				go func(s, e int) {
					defer wg.Done()
					for n := s; n < e; n++ {
						for m := 0; m < M; m++ {
							var sum float32
							for b := 0; b < blocksPerRow; b++ {
								inOffset := m*K + b*256
								wBlockIdx := n*blocksPerRow + b
								block := &wData[wBlockIdx]
								sum += tensor.DotQ8_K(block, inData[inOffset:inOffset+256])
							}
							outData[m*N+n] = sum
						}
					}
				}(start, end)
			}
			wg.Wait()
			return nil
		}
		for _, r := range chunkRanges(M, use) {
			wg.Add(1)
			start, end := r[0], r[1]
			go func(s, e int) {
				defer wg.Done()
				for m := s; m < e; m++ {
					for n := 0; n < N; n++ {
						var sum float32
						for b := 0; b < blocksPerRow; b++ {
							inOffset := m*K + b*256
							wBlockIdx := n*blocksPerRow + b
							block := &wData[wBlockIdx]
							sum += tensor.DotQ8_K(block, inData[inOffset:inOffset+256])
						}
						outData[m*N+n] = sum
					}
				}
			}(start, end)
		}
		wg.Wait()

	case tensor.Q3_K:
		wData := weight.DataQ3_K()
		if K%256 != 0 {
			return fmt.Errorf("linear: Q3_K weight K dimension %d must be multiple of 256", K)
		}
		wParams := tensor.GetQ3KDotParams(wData)
		blocksPerRow := K / 256
		work := M * N * K
		use := chooseWorkers(work, workers)
		if use == 1 {
			if M == 1 {
				q3kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, 0, N)
				return nil
			}
			for m := 0; m < M; m++ {
				for n := 0; n < N; n++ {
					var sum float32
					for b := 0; b < blocksPerRow; b++ {
						inOffset := m*K + b*256
						wBlockIdx := n*blocksPerRow + b
						block := &wData[wBlockIdx]
						p := &wParams[wBlockIdx]
						sum += tensor.DotQ3_K_Params(block, p, inData[inOffset:inOffset+256])
					}
					outData[m*N+n] = sum
				}
			}
			return nil
		}
		var wg sync.WaitGroup
		if M == 1 {
			for _, r := range chunkRanges(N, use) {
				wg.Add(1)
				start, end := r[0], r[1]
				go func(s, e int) {
					defer wg.Done()
					q3kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, s, e)
				}(start, end)
			}
			wg.Wait()
			return nil
		}
		if M < use {
			for _, r := range chunkRanges(N, use) {
				wg.Add(1)
				start, end := r[0], r[1]
				go func(s, e int) {
					defer wg.Done()
					for n := s; n < e; n++ {
						for m := 0; m < M; m++ {
							var sum float32
							for b := 0; b < blocksPerRow; b++ {
								inOffset := m*K + b*256
								wBlockIdx := n*blocksPerRow + b
								block := &wData[wBlockIdx]
								p := &wParams[wBlockIdx]
								sum += tensor.DotQ3_K_Params(block, p, inData[inOffset:inOffset+256])
							}
							outData[m*N+n] = sum
						}
					}
				}(start, end)
			}
			wg.Wait()
			return nil
		}
		for _, r := range chunkRanges(M, use) {
			wg.Add(1)
			start, end := r[0], r[1]
			go func(s, e int) {
				defer wg.Done()
				for m := s; m < e; m++ {
					for n := 0; n < N; n++ {
						var sum float32
						for b := 0; b < blocksPerRow; b++ {
							inOffset := m*K + b*256
							wBlockIdx := n*blocksPerRow + b
							block := &wData[wBlockIdx]
							p := &wParams[wBlockIdx]
							sum += tensor.DotQ3_K_Params(block, p, inData[inOffset:inOffset+256])
						}
						outData[m*N+n] = sum
					}
				}
			}(start, end)
		}
		wg.Wait()
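	// Weight blocks are laid out row-major: row n occupies
	// wData[n*blocksPerRow:(n+1)*blocksPerRow], and block b of that row
	// pairs with inData[m*K+b*256 : m*K+(b+1)*256]. With K = 4096, for
	// example, each row is 16 superblocks of 256 weights each.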
	case tensor.Q5_K:
		wData := weight.DataQ5_K()
		if K%256 != 0 {
			return fmt.Errorf("linear: Q5_K weight K dimension %d must be multiple of 256", K)
		}
		wParams := tensor.GetQ5KDotParams(wData)
		blocksPerRow := K / 256
		work := M * N * K
		use := chooseWorkers(work, workers)
		if use == 1 {
			if M == 1 {
				q5kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, 0, N)
				return nil
			}
			for m := 0; m < M; m++ {
				for n := 0; n < N; n++ {
					var sum float32
					for b := 0; b < blocksPerRow; b++ {
						inOffset := m*K + b*256
						wBlockIdx := n*blocksPerRow + b
						block := &wData[wBlockIdx]
						p := &wParams[wBlockIdx]
						sum += tensor.DotQ5_K_Params(block, p, inData[inOffset:inOffset+256])
					}
					outData[m*N+n] = sum
				}
			}
			return nil
		}
		var wg sync.WaitGroup
		if M == 1 {
			for _, r := range chunkRanges(N, use) {
				wg.Add(1)
				start, end := r[0], r[1]
				go func(s, e int) {
					defer wg.Done()
					q5kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, s, e)
				}(start, end)
			}
			wg.Wait()
			return nil
		}
		if M < use {
			for _, r := range chunkRanges(N, use) {
				wg.Add(1)
				start, end := r[0], r[1]
				go func(s, e int) {
					defer wg.Done()
					for n := s; n < e; n++ {
						for m := 0; m < M; m++ {
							var sum float32
							for b := 0; b < blocksPerRow; b++ {
								inOffset := m*K + b*256
								wBlockIdx := n*blocksPerRow + b
								block := &wData[wBlockIdx]
								p := &wParams[wBlockIdx]
								sum += tensor.DotQ5_K_Params(block, p, inData[inOffset:inOffset+256])
							}
							outData[m*N+n] = sum
						}
					}
				}(start, end)
			}
			wg.Wait()
			return nil
		}
		for _, r := range chunkRanges(M, use) {
			wg.Add(1)
			start, end := r[0], r[1]
			go func(s, e int) {
				defer wg.Done()
				for m := s; m < e; m++ {
					for n := 0; n < N; n++ {
						var sum float32
						for b := 0; b < blocksPerRow; b++ {
							inOffset := m*K + b*256
							wBlockIdx := n*blocksPerRow + b
							block := &wData[wBlockIdx]
							p := &wParams[wBlockIdx]
							sum += tensor.DotQ5_K_Params(block, p, inData[inOffset:inOffset+256])
						}
						outData[m*N+n] = sum
					}
				}
			}(start, end)
		}
		wg.Wait()

	case tensor.Q6_K:
		wData := weight.DataQ6_K()
		if K%256 != 0 {
			return fmt.Errorf("linear: Q6_K weight K dimension %d must be multiple of 256", K)
		}
		wParams := tensor.GetQ6KDotParams(wData)
		blocksPerRow := K / 256
		work := M * N * K
		use := chooseWorkers(work, workers)
		if use == 1 {
			if M == 1 {
				q6kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, 0, N)
				return nil
			}
			for m := 0; m < M; m++ {
				for n := 0; n < N; n++ {
					var sum float32
					for b := 0; b < blocksPerRow; b++ {
						inOffset := m*K + b*256
						wBlockIdx := n*blocksPerRow + b
						block := &wData[wBlockIdx]
						p := &wParams[wBlockIdx]
						sum += tensor.DotQ6_K_Params(block, p, inData[inOffset:inOffset+256])
					}
					outData[m*N+n] = sum
				}
			}
			return nil
		}
		var wg sync.WaitGroup
		if M == 1 {
			for _, r := range chunkRanges(N, use) {
				wg.Add(1)
				start, end := r[0], r[1]
				go func(s, e int) {
					defer wg.Done()
					q6kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, s, e)
				}(start, end)
			}
			wg.Wait()
			return nil
		}
		if M < use {
			for _, r := range chunkRanges(N, use) {
				wg.Add(1)
				start, end := r[0], r[1]
				go func(s, e int) {
					defer wg.Done()
					for n := s; n < e; n++ {
						for m := 0; m < M; m++ {
							var sum float32
							for b := 0; b < blocksPerRow; b++ {
								inOffset := m*K + b*256
								wBlockIdx := n*blocksPerRow + b
								block := &wData[wBlockIdx]
								p := &wParams[wBlockIdx]
								sum += tensor.DotQ6_K_Params(block, p, inData[inOffset:inOffset+256])
							}
							outData[m*N+n] = sum
						}
					}
				}(start, end)
			}
			wg.Wait()
			return nil
		}
		for _, r := range chunkRanges(M, use) {
			wg.Add(1)
			start, end := r[0], r[1]
			go func(s, e int) {
				defer wg.Done()
				for m := s; m < e; m++ {
					for n := 0; n < N; n++ {
						var sum float32
						for b := 0; b < blocksPerRow; b++ {
							inOffset := m*K + b*256
							wBlockIdx := n*blocksPerRow + b
							block := &wData[wBlockIdx]
							p := &wParams[wBlockIdx]
							sum += tensor.DotQ6_K_Params(block, p, inData[inOffset:inOffset+256])
						}
						outData[m*N+n] = sum
					}
				}
			}(start, end)
		}
		wg.Wait()
	case tensor.Q2_K:
		wData := weight.DataQ2_K()
		if K%256 != 0 {
			return fmt.Errorf("linear: Q2_K weight K dimension %d must be multiple of 256", K)
		}
		wParams := tensor.GetQ2KDotParams(wData)
		blocksPerRow := K / 256
		work := M * N * K
		use := chooseWorkers(work, workers)
		if use == 1 {
			if M == 1 {
				q2kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, 0, N)
				return nil
			}
			for m := 0; m < M; m++ {
				for n := 0; n < N; n++ {
					var sum float32
					for b := 0; b < blocksPerRow; b++ {
						inOffset := m*K + b*256
						wBlockIdx := n*blocksPerRow + b
						block := &wData[wBlockIdx]
						p := &wParams[wBlockIdx]
						sum += tensor.DotQ2_K_Params(block, p, inData[inOffset:inOffset+256])
					}
					outData[m*N+n] = sum
				}
			}
			return nil
		}
		var wg sync.WaitGroup
		if M == 1 {
			for _, r := range chunkRanges(N, use) {
				wg.Add(1)
				start, end := r[0], r[1]
				go func(s, e int) {
					defer wg.Done()
					q2kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, s, e)
				}(start, end)
			}
			wg.Wait()
			return nil
		}
		if M < use {
			for _, r := range chunkRanges(N, use) {
				wg.Add(1)
				start, end := r[0], r[1]
				go func(s, e int) {
					defer wg.Done()
					for n := s; n < e; n++ {
						for m := 0; m < M; m++ {
							var sum float32
							for b := 0; b < blocksPerRow; b++ {
								inOffset := m*K + b*256
								wBlockIdx := n*blocksPerRow + b
								block := &wData[wBlockIdx]
								p := &wParams[wBlockIdx]
								sum += tensor.DotQ2_K_Params(block, p, inData[inOffset:inOffset+256])
							}
							outData[m*N+n] = sum
						}
					}
				}(start, end)
			}
			wg.Wait()
			return nil
		}
		for _, r := range chunkRanges(M, use) {
			wg.Add(1)
			start, end := r[0], r[1]
			go func(s, e int) {
				defer wg.Done()
				for m := s; m < e; m++ {
					for n := 0; n < N; n++ {
						var sum float32
						for b := 0; b < blocksPerRow; b++ {
							inOffset := m*K + b*256
							wBlockIdx := n*blocksPerRow + b
							block := &wData[wBlockIdx]
							p := &wParams[wBlockIdx]
							sum += tensor.DotQ2_K_Params(block, p, inData[inOffset:inOffset+256])
						}
						outData[m*N+n] = sum
					}
				}
			}(start, end)
		}
		wg.Wait()

	default:
		return fmt.Errorf("linear: unsupported weight dtype %v", weight.DType())
	}
	return nil
}
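// The gemv helpers below walk the output vector in tiles of 8 columns: each
// 256-float activation block x[b*256:] is handed to a tensor.DotQ*KTile8
// kernel once per tile, so it is presumably read once for eight weight rows
// rather than once per row. tn trims the final tile when the range length is
// not a multiple of 8.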
func q4kGemvDecodeTiled(out []float32, x []float32, w []tensor.BlockQ4_K, wp []tensor.Q4KDotParams, N, blocksPerRow, startN, endN int) {
	const tile = 8
	for n := startN; n < endN; n += tile {
		tn := endN - n
		if tn > tile {
			tn = tile
		}
		var sums [tile]float32
		for b := 0; b < blocksPerRow; b++ {
			xBlock := &x[b*256]
			base := n*blocksPerRow + b
			tensor.DotQ4KTile8(&sums, w, wp, base, blocksPerRow, xBlock, tn)
		}
		for t := 0; t < tn; t++ {
			out[n+t] = sums[t]
		}
	}
}

func q5kGemvDecodeTiled(out []float32, x []float32, w []tensor.BlockQ5_K, wp []tensor.Q5KDotParams, N, blocksPerRow, startN, endN int) {
	const tile = 8
	for n := startN; n < endN; n += tile {
		tn := endN - n
		if tn > tile {
			tn = tile
		}
		var sums [tile]float32
		for b := 0; b < blocksPerRow; b++ {
			xBlock := &x[b*256]
			base := n*blocksPerRow + b
			tensor.DotQ5KTile8(&sums, w, wp, base, blocksPerRow, xBlock, tn)
		}
		for t := 0; t < tn; t++ {
			out[n+t] = sums[t]
		}
	}
}

func q6kGemvDecodeTiled(out []float32, x []float32, w []tensor.BlockQ6_K, wp []tensor.Q6KDotParams, N, blocksPerRow, startN, endN int) {
	const tile = 8
	for n := startN; n < endN; n += tile {
		tn := endN - n
		if tn > tile {
			tn = tile
		}
		var sums [tile]float32
		for b := 0; b < blocksPerRow; b++ {
			xBlock := &x[b*256]
			base := n*blocksPerRow + b
			tensor.DotQ6KTile8(&sums, w, wp, base, blocksPerRow, xBlock, tn)
		}
		for t := 0; t < tn; t++ {
			out[n+t] = sums[t]
		}
	}
}

func q3kGemvDecodeTiled(out []float32, x []float32, w []tensor.BlockQ3_K, wp []tensor.Q3KDotParams, N, blocksPerRow, startN, endN int) {
	const tile = 8
	for n := startN; n < endN; n += tile {
		tn := endN - n
		if tn > tile {
			tn = tile
		}
		var sums [tile]float32
		for b := 0; b < blocksPerRow; b++ {
			xBlock := &x[b*256]
			base := n*blocksPerRow + b
			tensor.DotQ3KTile8(&sums, w, wp, base, blocksPerRow, xBlock, tn)
		}
		for t := 0; t < tn; t++ {
			out[n+t] = sums[t]
		}
	}
}

func q2kGemvDecodeTiled(out []float32, x []float32, w []tensor.BlockQ2_K, wp []tensor.Q2KDotParams, N, blocksPerRow, startN, endN int) {
	const tile = 8
	for n := startN; n < endN; n += tile {
		tn := endN - n
		if tn > tile {
			tn = tile
		}
		var sums [tile]float32
		for b := 0; b < blocksPerRow; b++ {
			xBlock := &x[b*256]
			base := n*blocksPerRow + b
			tensor.DotQ2KTile8(&sums, w, wp, base, blocksPerRow, xBlock, tn)
		}
		for t := 0; t < tn; t++ {
			out[n+t] = sums[t]
		}
	}
}

func q8kGemvDecodeTiled(out []float32, x []float32, w []tensor.BlockQ8_K, N, blocksPerRow, startN, endN int) {
	const tile = 8
	for n := startN; n < endN; n += tile {
		tn := endN - n
		if tn > tile {
			tn = tile
		}
		var sums [tile]float32
		for b := 0; b < blocksPerRow; b++ {
			xBlock := &x[b*256]
			base := n*blocksPerRow + b
			tensor.DotQ8KTile8(&sums, w, base, blocksPerRow, xBlock, tn)
		}
		for t := 0; t < tn; t++ {
			out[n+t] = sums[t]
		}
	}
}