repeat.go (2.3 KB)
  1. package ops
  2. import (
  3. "makarna/pkg/backend/cpu"
  4. "makarna/pkg/tensor"
  5. )
  6. // Repeat repeats tensor n times along dimension dim
  7. func Repeat(t *cpu.Tensor, dim, n int) *cpu.Tensor {
  8. shape := t.Shape()
  9. data := t.DataFloat32()
  10. newShape := make(tensor.Shape, len(shape))
  11. copy(newShape, shape)
  12. newShape[dim] *= n
  13. // Simple case: repeat along first dimension
  14. if dim == 0 {
  15. newData := make([]float32, 0, newShape.NumElements())
  16. for i := 0; i < n; i++ {
  17. newData = append(newData, data...)
  18. }
  19. return cpu.NewTensor(newShape, newData)
  20. }
  21. // Simple case: repeat along last dimension
  22. if dim == len(shape)-1 {
  23. newData := make([]float32, newShape.NumElements())
  24. rowSize := shape[dim]
  25. newRowSize := newShape[dim]
  26. numRows := len(data) / rowSize
  27. for row := 0; row < numRows; row++ {
  28. srcStart := row * rowSize
  29. dstStart := row * newRowSize
  30. for rep := 0; rep < n; rep++ {
  31. copy(newData[dstStart+rep*rowSize:], data[srcStart:srcStart+rowSize])
  32. }
  33. }
  34. return cpu.NewTensor(newShape, newData)
  35. }
  36. // General case: repeat along middle dimension
  37. // Calculate outer (before dim), inner (after dim) sizes
  38. outerSize := 1
  39. for i := 0; i < dim; i++ {
  40. outerSize *= shape[i]
  41. }
  42. innerSize := 1
  43. for i := dim + 1; i < len(shape); i++ {
  44. innerSize *= shape[i]
  45. }
  46. dimSize := shape[dim]
  47. sliceSize := dimSize * innerSize
  48. newData := make([]float32, newShape.NumElements())
  49. dstIdx := 0
  50. for outer := 0; outer < outerSize; outer++ {
  51. srcStart := outer * sliceSize
  52. srcSlice := data[srcStart : srcStart+sliceSize]
  53. for rep := 0; rep < n; rep++ {
  54. copy(newData[dstIdx:], srcSlice)
  55. dstIdx += sliceSize
  56. }
  57. }
  58. return cpu.NewTensor(newShape, newData)
  59. }
  60. // RepeatInterleave repeats each element n times along dimension
  61. func RepeatInterleave(t *cpu.Tensor, dim, n int) *cpu.Tensor {
  62. shape := t.Shape()
  63. data := t.DataFloat32()
  64. newShape := make(tensor.Shape, len(shape))
  65. copy(newShape, shape)
  66. newShape[dim] *= n
  67. // Handle 2D case with dim=0
  68. if len(shape) == 2 && dim == 0 {
  69. newData := make([]float32, newShape.NumElements())
  70. rowSize := shape[1]
  71. for row := 0; row < shape[0]; row++ {
  72. srcRow := data[row*rowSize : (row+1)*rowSize]
  73. for rep := 0; rep < n; rep++ {
  74. dstStart := (row*n + rep) * rowSize
  75. copy(newData[dstStart:], srcRow)
  76. }
  77. }
  78. return cpu.NewTensor(newShape, newData)
  79. }
  80. return nil
  81. }