tensor_utils.go 384 B

123456789101112131415161718192021
  1. package nn
  2. import (
  3. "fmt"
  4. "makarna/pkg/backend/cpu"
  5. )
  6. func FlattenVector(t *cpu.Tensor, n int, name string) ([]float32, error) {
  7. if t == nil {
  8. return nil, fmt.Errorf("missing %s", name)
  9. }
  10. data := t.DataFloat32()
  11. if len(data) == n {
  12. return data, nil
  13. }
  14. if len(data) >= n {
  15. return data[:n], nil
  16. }
  17. return nil, fmt.Errorf("%s has unexpected size %d", name, len(data))
  18. }