nn_simd_test.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. package nn
  2. import (
  3. "math"
  4. "math/rand"
  5. "testing"
  6. "makarna/pkg/backend/cpu"
  7. "makarna/pkg/tensor"
  8. )
  9. func TestRMSNormMatchesReference(t *testing.T) {
  10. dim := 8
  11. x := cpu.NewTensor(tensor.Shape{1, dim}, []float32{1, 2, 3, 4, 5, 6, 7, 8})
  12. w := cpu.NewTensor(tensor.Shape{dim}, make([]float32, dim))
  13. for i := range w.DataFloat32() {
  14. w.DataFloat32()[i] = 1
  15. }
  16. if err := RMSNorm(x, w, 1e-5); err != nil {
  17. t.Fatalf("rmsnorm err: %v", err)
  18. }
  19. // Reference
  20. row := x.DataFloat32()
  21. var ss float32
  22. for _, v := range []float32{1, 2, 3, 4, 5, 6, 7, 8} {
  23. ss += v * v
  24. }
  25. ss /= float32(dim)
  26. inv := 1 / float32(math.Sqrt(float64(ss+1e-5)))
  27. for i, v := range []float32{1, 2, 3, 4, 5, 6, 7, 8} {
  28. want := v * inv
  29. if diff := absDiff(row[i], want); diff > 1e-4 {
  30. t.Fatalf("rmsnorm mismatch at %d: got %f want %f", i, row[i], want)
  31. }
  32. }
  33. }
  34. func TestSoftmaxSumsToOne(t *testing.T) {
  35. data := []float32{0.1, 1.2, -0.3, 0.4}
  36. x := cpu.NewTensor(tensor.Shape{len(data)}, append([]float32(nil), data...))
  37. if err := Softmax(x); err != nil {
  38. t.Fatalf("softmax err: %v", err)
  39. }
  40. sum := float32(0)
  41. for _, v := range x.DataFloat32() {
  42. sum += v
  43. if v <= 0 {
  44. t.Fatalf("softmax produced non-positive prob %f", v)
  45. }
  46. }
  47. if diff := absDiff(sum, 1); diff > 1e-5 {
  48. t.Fatalf("softmax sum != 1: got %f", sum)
  49. }
  50. }
  51. func TestRoPENoNaN(t *testing.T) {
  52. headDim := 4
  53. seq := 3
  54. data := make([]float32, seq*headDim)
  55. for i := range data {
  56. data[i] = rand.Float32()
  57. }
  58. x := cpu.NewTensor(tensor.Shape{seq, headDim}, data)
  59. positions := []int{0, 1, 2}
  60. if err := RoPE(x, positions, headDim, 10000); err != nil {
  61. t.Fatalf("rope err: %v", err)
  62. }
  63. for i, v := range x.DataFloat32() {
  64. if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
  65. t.Fatalf("rope produced invalid value at %d: %f", i, v)
  66. }
  67. }
  68. }
  69. func absDiff(a, b float32) float32 {
  70. if a > b {
  71. return a - b
  72. }
  73. return b - a
  74. }