| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- package ops
- import (
- "makarna/pkg/backend/cpu"
- "makarna/pkg/tensor"
- )
- // Permute reorders tensor dimensions
- // Example: Permute(t, 1, 0, 2) swaps first two dims
- // Returns a new tensor (not a view - copies data for simplicity)
- func Permute(t *cpu.Tensor, order ...int) *cpu.Tensor {
- oldShape := t.Shape()
- oldData := t.DataFloat32()
-
- // Compute new shape
- newShape := make(tensor.Shape, len(order))
- for i, o := range order {
- newShape[i] = oldShape[o]
- }
-
- // Compute old strides
- oldStrides := make([]int, len(oldShape))
- oldStrides[len(oldShape)-1] = 1
- for i := len(oldShape) - 2; i >= 0; i-- {
- oldStrides[i] = oldStrides[i+1] * oldShape[i+1]
- }
-
- // Compute new strides
- newStrides := make([]int, len(newShape))
- newStrides[len(newShape)-1] = 1
- for i := len(newShape) - 2; i >= 0; i-- {
- newStrides[i] = newStrides[i+1] * newShape[i+1]
- }
-
- newData := make([]float32, newShape.NumElements())
-
- // Iterate over all new indices and map to old
- indices := make([]int, len(newShape))
- for i := 0; i < len(newData); i++ {
- // Compute old flat index
- oldFlatIdx := 0
- for d := 0; d < len(order); d++ {
- oldFlatIdx += indices[d] * oldStrides[order[d]]
- }
- newData[i] = oldData[oldFlatIdx]
-
- // Increment indices
- for d := len(indices) - 1; d >= 0; d-- {
- indices[d]++
- if indices[d] < newShape[d] {
- break
- }
- indices[d] = 0
- }
- }
-
- return cpu.NewTensor(newShape, newData)
- }
|