| 123456789101112131415161718192021222324252627282930313233343536373839404142434445 |
- package ops
- import (
- "makarna/pkg/backend/cpu"
- "makarna/pkg/tensor"
- )
- // Reshape returns a new tensor with different shape (same data)
- func Reshape(t *cpu.Tensor, shape tensor.Shape) *cpu.Tensor {
- return cpu.NewTensor(shape, t.DataFloat32())
- }
- // View is an alias for Reshape
- func View(t *cpu.Tensor, shape tensor.Shape) *cpu.Tensor {
- return Reshape(t, shape)
- }
- // Squeeze removes dimensions of size 1
- func Squeeze(t *cpu.Tensor) *cpu.Tensor {
- oldShape := t.Shape()
- newShape := make(tensor.Shape, 0)
- for _, s := range oldShape {
- if s != 1 {
- newShape = append(newShape, s)
- }
- }
- if len(newShape) == 0 {
- newShape = tensor.Shape{1} // Scalar
- }
- return cpu.NewTensor(newShape, t.DataFloat32())
- }
- // Unsqueeze adds a dimension of size 1 at the specified position
- func Unsqueeze(t *cpu.Tensor, dim int) *cpu.Tensor {
- oldShape := t.Shape()
- newShape := make(tensor.Shape, len(oldShape)+1)
- for i := 0; i < dim; i++ {
- newShape[i] = oldShape[i]
- }
- newShape[dim] = 1
- for i := dim; i < len(oldShape); i++ {
- newShape[i+1] = oldShape[i]
- }
- return cpu.NewTensor(newShape, t.DataFloat32())
- }
|