backend_test.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. package tests
  2. import (
  3. "math"
  4. "testing"
  5. "makarna/pkg/backend/cpu"
  6. "makarna/pkg/backend/cpu/matmul"
  7. "makarna/pkg/backend/cpu/nn"
  8. "makarna/pkg/tensor"
  9. )
  10. // Helper for approximate float comparison
  11. func assertNearlyEqual(t *testing.T, expected, actual []float32, epsilon float32, msg string) {
  12. if len(expected) != len(actual) {
  13. t.Fatalf("%s: length mismatch: expected %d, got %d", msg, len(expected), len(actual))
  14. }
  15. for i := range expected {
  16. diff := float32(math.Abs(float64(expected[i] - actual[i])))
  17. if diff > epsilon {
  18. t.Errorf("%s: element %d mismatch: expected %f, got %f (diff %f > %f)", msg, i, expected[i], actual[i], diff, epsilon)
  19. return // Fail fast
  20. }
  21. }
  22. }
  23. func TestLinear_F32(t *testing.T) {
  24. // A: [2, 3]
  25. // W: [2, 3] (2 output features, 3 input features)
  26. // C: [2, 2]
  27. aData := []float32{
  28. 1, 2, 3,
  29. 4, 5, 6,
  30. }
  31. wData := []float32{
  32. 0.1, 0.2, 0.3, // Neuron 0
  33. -0.1, -0.2, -0.3, // Neuron 1
  34. }
  35. a := cpu.NewTensor(tensor.Shape{2, 3}, aData)
  36. w := cpu.NewTensor(tensor.Shape{2, 3}, wData)
  37. c := cpu.NewTensor(tensor.Shape{2, 2}, nil)
  38. if err := matmul.Linear(a, w, c); err != nil {
  39. t.Fatalf("Linear failed: %v", err)
  40. }
  41. // Expected:
  42. // Row 0:
  43. // N0: 1*0.1 + 2*0.2 + 3*0.3 = 0.1+0.4+0.9 = 1.4
  44. // N1: 1*-0.1 + 2*-0.2 + 3*-0.3 = -0.1-0.4-0.9 = -1.4
  45. // Row 1:
  46. // N0: 4*0.1 + 5*0.2 + 6*0.3 = 0.4+1.0+1.8 = 3.2
  47. // N1: 4*-0.1 + 5*-0.2 + 6*-0.3 = -0.4-1.0-1.8 = -3.2
  48. expected := []float32{1.4, -1.4, 3.2, -3.2}
  49. assertNearlyEqual(t, expected, c.DataFloat32(), 1e-5, "Linear F32")
  50. }
  51. func TestEmbedding(t *testing.T) {
  52. // Vocab: 4, Dim: 3
  53. weights := []float32{
  54. 0, 0, 0, // ID 0
  55. 1, 1, 1, // ID 1
  56. 2, 2, 2, // ID 2
  57. 3, 3, 3, // ID 3
  58. }
  59. w := cpu.NewTensor(tensor.Shape{4, 3}, weights)
  60. // Sequence: [1, 3]
  61. ids := []int{1, 3}
  62. out := cpu.NewTensor(tensor.Shape{2, 3}, nil)
  63. if err := nn.Embedding(ids, w, out); err != nil {
  64. t.Fatalf("Embedding failed: %v", err)
  65. }
  66. expected := []float32{
  67. 1, 1, 1,
  68. 3, 3, 3,
  69. }
  70. assertNearlyEqual(t, expected, out.DataFloat32(), 1e-5, "Embedding")
  71. }
  72. func TestRMSNorm(t *testing.T) {
  73. // x: [1, 4] = [1, 2, 3, 4]
  74. // mean square = (1+4+9+16)/4 = 30/4 = 7.5
  75. // rms = sqrt(7.5 + eps) ~= 2.7386
  76. // weight: [1, 1, 1, 1]
  77. xData := []float32{1, 2, 3, 4}
  78. x := cpu.NewTensor(tensor.Shape{1, 4}, xData)
  79. w := cpu.NewTensor(tensor.Shape{4}, []float32{1, 1, 1, 1})
  80. err := nn.RMSNorm(x, w, 1e-5)
  81. if err != nil {
  82. t.Fatalf("RMSNorm failed: %v", err)
  83. }
  84. // Manual calc with Python: x / sqrt(mean(x**2) + eps)
  85. // 1/2.7386 = 0.3651
  86. // 2/2.7386 = 0.7303
  87. // ...
  88. expected := []float32{
  89. 0.365148, 0.730297, 1.095445, 1.460593,
  90. }
  91. assertNearlyEqual(t, expected, x.DataFloat32(), 1e-4, "RMSNorm")
  92. }
  93. func TestSoftmax(t *testing.T) {
  94. // Logits: [0, 1, 2]
  95. // exp: [1, 2.718, 7.389]
  96. // sum: 11.107
  97. // prob: [0.090, 0.244, 0.665]
  98. data := []float32{0, 1, 2}
  99. x := cpu.NewTensor(tensor.Shape{3}, data)
  100. nn.Softmax(x)
  101. expected := []float32{0.09003057, 0.24472847, 0.66524096}
  102. assertNearlyEqual(t, expected, x.DataFloat32(), 1e-5, "Softmax")
  103. }