permute.go 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. package ops
  2. import (
  3. "makarna/pkg/backend/cpu"
  4. "makarna/pkg/tensor"
  5. )
  6. // Permute reorders tensor dimensions
  7. // Example: Permute(t, 1, 0, 2) swaps first two dims
  8. // Returns a new tensor (not a view - copies data for simplicity)
  9. func Permute(t *cpu.Tensor, order ...int) *cpu.Tensor {
  10. oldShape := t.Shape()
  11. oldData := t.DataFloat32()
  12. // Compute new shape
  13. newShape := make(tensor.Shape, len(order))
  14. for i, o := range order {
  15. newShape[i] = oldShape[o]
  16. }
  17. // Compute old strides
  18. oldStrides := make([]int, len(oldShape))
  19. oldStrides[len(oldShape)-1] = 1
  20. for i := len(oldShape) - 2; i >= 0; i-- {
  21. oldStrides[i] = oldStrides[i+1] * oldShape[i+1]
  22. }
  23. // Compute new strides
  24. newStrides := make([]int, len(newShape))
  25. newStrides[len(newShape)-1] = 1
  26. for i := len(newShape) - 2; i >= 0; i-- {
  27. newStrides[i] = newStrides[i+1] * newShape[i+1]
  28. }
  29. newData := make([]float32, newShape.NumElements())
  30. // Iterate over all new indices and map to old
  31. indices := make([]int, len(newShape))
  32. for i := 0; i < len(newData); i++ {
  33. // Compute old flat index
  34. oldFlatIdx := 0
  35. for d := 0; d < len(order); d++ {
  36. oldFlatIdx += indices[d] * oldStrides[order[d]]
  37. }
  38. newData[i] = oldData[oldFlatIdx]
  39. // Increment indices
  40. for d := len(indices) - 1; d >= 0; d-- {
  41. indices[d]++
  42. if indices[d] < newShape[d] {
  43. break
  44. }
  45. indices[d] = 0
  46. }
  47. }
  48. return cpu.NewTensor(newShape, newData)
  49. }