1
0

compute_test.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package compute
  2. import (
  3. "testing"
  4. "makarna/pkg/backend/cpu"
  5. "makarna/pkg/backend/device"
  6. "makarna/pkg/tensor"
  7. )
  8. func TestLinearCPU(t *testing.T) {
  9. // Input: [2, 3] @ Weight: [4, 3] = Output: [2, 4]
  10. input := cpu.NewTensor(tensor.Shape{2, 3}, []float32{
  11. 1, 2, 3,
  12. 4, 5, 6,
  13. })
  14. weight := cpu.NewTensor(tensor.Shape{4, 3}, []float32{
  15. 1, 0, 0,
  16. 0, 1, 0,
  17. 0, 0, 1,
  18. 1, 1, 1,
  19. })
  20. output := cpu.NewTensor(tensor.Shape{2, 4}, nil)
  21. ctx := NewContext(nil, 0) // nil dispatcher = CPU
  22. if err := Linear(ctx, input, weight, output); err != nil {
  23. t.Fatalf("Linear failed: %v", err)
  24. }
  25. expected := []float32{
  26. 1, 2, 3, 6, // row 0: [1,2,3] dot each weight row
  27. 4, 5, 6, 15, // row 1
  28. }
  29. outData := output.DataFloat32()
  30. for i, exp := range expected {
  31. if diff := outData[i] - exp; diff < -0.001 || diff > 0.001 {
  32. t.Errorf("output[%d] = %f, expected %f", i, outData[i], exp)
  33. }
  34. }
  35. }
  36. func TestRMSNorm(t *testing.T) {
  37. x := cpu.NewTensor(tensor.Shape{1, 4}, []float32{1, 2, 3, 4})
  38. w := cpu.NewTensor(tensor.Shape{4}, []float32{1, 1, 1, 1})
  39. ctx := NewContext(nil, 0)
  40. if err := RMSNorm(ctx, x, w, 1e-6); err != nil {
  41. t.Fatalf("RMSNorm failed: %v", err)
  42. }
  43. // Check output is normalized
  44. data := x.DataFloat32()
  45. var ss float32
  46. for _, v := range data {
  47. ss += v * v
  48. }
  49. rms := ss / 4
  50. // After RMSNorm, variance should be close to 1
  51. if rms < 0.9 || rms > 1.1 {
  52. t.Errorf("RMS after norm = %f, expected ~1.0", rms)
  53. }
  54. }
  55. func TestDeviceDispatcher(t *testing.T) {
  56. placements := []tensor.DevicePlacement{
  57. {Type: tensor.CUDA, GPU: 0},
  58. {Type: tensor.CUDA, GPU: 0},
  59. {Type: tensor.CPU, GPU: -1},
  60. {Type: tensor.CPU, GPU: -1},
  61. }
  62. dd := device.NewDeviceDispatcher(placements)
  63. if dd.NumGPULayers() != 2 {
  64. t.Errorf("NumGPULayers = %d, expected 2", dd.NumGPULayers())
  65. }
  66. if !dd.IsLayerOnGPU(0) {
  67. t.Error("Layer 0 should be on GPU")
  68. }
  69. if dd.IsLayerOnGPU(2) {
  70. t.Error("Layer 2 should be on CPU")
  71. }
  72. p := dd.LayerPlacement(1)
  73. if p.Type != tensor.CUDA {
  74. t.Errorf("Layer 1 placement = %v, expected CUDA", p.Type)
  75. }
  76. // Beyond bounds defaults to CPU
  77. p = dd.LayerPlacement(100)
  78. if p.Type != tensor.CPU {
  79. t.Errorf("Out of bounds placement = %v, expected CPU", p.Type)
  80. }
  81. }
  82. func TestContextPlacement(t *testing.T) {
  83. placements := []tensor.DevicePlacement{
  84. {Type: tensor.CUDA, GPU: 0},
  85. {Type: tensor.CPU, GPU: -1},
  86. }
  87. dd := device.NewDeviceDispatcher(placements)
  88. ctx0 := NewContext(dd, 0)
  89. if !ctx0.IsGPU() {
  90. t.Error("Context 0 should be GPU")
  91. }
  92. ctx1 := NewContext(dd, 1)
  93. if ctx1.IsGPU() {
  94. t.Error("Context 1 should be CPU")
  95. }
  96. // Nil dispatcher
  97. ctxNil := NewContext(nil, 0)
  98. if ctxNil.IsGPU() {
  99. t.Error("Nil dispatcher should default to CPU")
  100. }
  101. }
  102. func TestSwiGLU(t *testing.T) {
  103. gate := cpu.NewTensor(tensor.Shape{2}, []float32{0, 1})
  104. up := cpu.NewTensor(tensor.Shape{2}, []float32{2, 3})
  105. out := cpu.NewTensor(tensor.Shape{2}, nil)
  106. ctx := NewContext(nil, 0)
  107. if err := SwiGLU(ctx, gate, up, out); err != nil {
  108. t.Fatalf("SwiGLU failed: %v", err)
  109. }
  110. // SiLU(0) = 0, so out[0] = 0 * 2 = 0
  111. // SiLU(1) ≈ 0.731, so out[1] ≈ 0.731 * 3 ≈ 2.19
  112. data := out.DataFloat32()
  113. if data[0] != 0 {
  114. t.Errorf("out[0] = %f, expected 0", data[0])
  115. }
  116. if data[1] < 2.0 || data[1] > 2.5 {
  117. t.Errorf("out[1] = %f, expected ~2.2", data[1])
  118. }
  119. }