| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- package nn
- import (
- "math"
- "math/rand"
- "testing"
- "makarna/pkg/backend/cpu"
- "makarna/pkg/tensor"
- )
- func TestRMSNormMatchesReference(t *testing.T) {
- dim := 8
- x := cpu.NewTensor(tensor.Shape{1, dim}, []float32{1, 2, 3, 4, 5, 6, 7, 8})
- w := cpu.NewTensor(tensor.Shape{dim}, make([]float32, dim))
- for i := range w.DataFloat32() {
- w.DataFloat32()[i] = 1
- }
- if err := RMSNorm(x, w, 1e-5); err != nil {
- t.Fatalf("rmsnorm err: %v", err)
- }
- // Reference
- row := x.DataFloat32()
- var ss float32
- for _, v := range []float32{1, 2, 3, 4, 5, 6, 7, 8} {
- ss += v * v
- }
- ss /= float32(dim)
- inv := 1 / float32(math.Sqrt(float64(ss+1e-5)))
- for i, v := range []float32{1, 2, 3, 4, 5, 6, 7, 8} {
- want := v * inv
- if diff := absDiff(row[i], want); diff > 1e-4 {
- t.Fatalf("rmsnorm mismatch at %d: got %f want %f", i, row[i], want)
- }
- }
- }
- func TestSoftmaxSumsToOne(t *testing.T) {
- data := []float32{0.1, 1.2, -0.3, 0.4}
- x := cpu.NewTensor(tensor.Shape{len(data)}, append([]float32(nil), data...))
- if err := Softmax(x); err != nil {
- t.Fatalf("softmax err: %v", err)
- }
- sum := float32(0)
- for _, v := range x.DataFloat32() {
- sum += v
- if v <= 0 {
- t.Fatalf("softmax produced non-positive prob %f", v)
- }
- }
- if diff := absDiff(sum, 1); diff > 1e-5 {
- t.Fatalf("softmax sum != 1: got %f", sum)
- }
- }
- func TestRoPENoNaN(t *testing.T) {
- headDim := 4
- seq := 3
- data := make([]float32, seq*headDim)
- for i := range data {
- data[i] = rand.Float32()
- }
- x := cpu.NewTensor(tensor.Shape{seq, headDim}, data)
- positions := []int{0, 1, 2}
- if err := RoPE(x, positions, headDim, 10000); err != nil {
- t.Fatalf("rope err: %v", err)
- }
- for i, v := range x.DataFloat32() {
- if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
- t.Fatalf("rope produced invalid value at %d: %f", i, v)
- }
- }
- }
// absDiff returns the absolute difference |a - b| of two float32 values. It
// stays in float32 throughout, with no round trip through float64/math.Abs.
func absDiff(a, b float32) float32 {
	// Subtract the smaller operand from the larger; ties and unordered
	// comparisons fall through to b - a.
	hi, lo := b, a
	if a > b {
		hi, lo = a, b
	}
	return hi - lo
}
|