gemv_f32_tiled_test.go 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. package matmul
  2. import (
  3. "math"
  4. "math/rand"
  5. "testing"
  6. "makarna/pkg/backend/cpu"
  7. )
  8. func TestGemvFloat32RangeMatchesScalar(t *testing.T) {
  9. rng := rand.New(rand.NewSource(1))
  10. cases := []struct {
  11. K int
  12. N int
  13. startN int
  14. endN int
  15. }{
  16. {K: 3, N: 13, startN: 0, endN: 13},
  17. {K: 33, N: 17, startN: 0, endN: 17},
  18. {K: 33, N: 17, startN: 5, endN: 13},
  19. {K: 32, N: 16, startN: 0, endN: 16},
  20. }
  21. for _, tc := range cases {
  22. a := make([]float32, tc.K)
  23. for i := range a {
  24. a[i] = rng.Float32()*2 - 1
  25. }
  26. b := make([]float32, tc.N*tc.K)
  27. for i := range b {
  28. b[i] = rng.Float32()*2 - 1
  29. }
  30. want := make([]float32, tc.N)
  31. for n := tc.startN; n < tc.endN; n++ {
  32. var sum float32
  33. rowOff := n * tc.K
  34. for k := 0; k < tc.K; k++ {
  35. sum += a[k] * b[rowOff+k]
  36. }
  37. want[n] = sum
  38. }
  39. // Sanity: cpu.DotFloat32 should broadly match the scalar sum.
  40. for n := tc.startN; n < tc.endN; n++ {
  41. rowOff := n * tc.K
  42. gotCPU := cpu.DotFloat32(a, b[rowOff:rowOff+tc.K])
  43. diff := math.Abs(float64(gotCPU - want[n]))
  44. if diff > 1e-4 {
  45. t.Fatalf("DotFloat32 mismatch K=%d n=%d: got=%v want=%v", tc.K, n, gotCPU, want[n])
  46. }
  47. }
  48. got := make([]float32, tc.N)
  49. gemvFloat32Range(got, a, b, tc.K, tc.startN, tc.endN)
  50. const tol = 1e-4
  51. for n := tc.startN; n < tc.endN; n++ {
  52. diff := math.Abs(float64(got[n] - want[n]))
  53. if diff > tol {
  54. t.Fatalf("K=%d N=%d range=[%d,%d) n=%d: got=%v want=%v diff=%g",
  55. tc.K, tc.N, tc.startN, tc.endN, n, got[n], want[n], diff)
  56. }
  57. }
  58. }
  59. }