reshape.go 1.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. package ops
  import (
  	"fmt"

  	"makarna/pkg/backend/cpu"
  	"makarna/pkg/tensor"
  )
  6. // Reshape returns a new tensor with different shape (same data)
  7. func Reshape(t *cpu.Tensor, shape tensor.Shape) *cpu.Tensor {
  8. return cpu.NewTensor(shape, t.DataFloat32())
  9. }
  10. // View is an alias for Reshape
  11. func View(t *cpu.Tensor, shape tensor.Shape) *cpu.Tensor {
  12. return Reshape(t, shape)
  13. }
  14. // Squeeze removes dimensions of size 1
  15. func Squeeze(t *cpu.Tensor) *cpu.Tensor {
  16. oldShape := t.Shape()
  17. newShape := make(tensor.Shape, 0)
  18. for _, s := range oldShape {
  19. if s != 1 {
  20. newShape = append(newShape, s)
  21. }
  22. }
  23. if len(newShape) == 0 {
  24. newShape = tensor.Shape{1} // Scalar
  25. }
  26. return cpu.NewTensor(newShape, t.DataFloat32())
  27. }
  28. // Unsqueeze adds a dimension of size 1 at the specified position
  29. func Unsqueeze(t *cpu.Tensor, dim int) *cpu.Tensor {
  30. oldShape := t.Shape()
  31. newShape := make(tensor.Shape, len(oldShape)+1)
  32. for i := 0; i < dim; i++ {
  33. newShape[i] = oldShape[i]
  34. }
  35. newShape[dim] = 1
  36. for i := dim; i < len(oldShape); i++ {
  37. newShape[i+1] = oldShape[i]
  38. }
  39. return cpu.NewTensor(newShape, t.DataFloat32())
  40. }