package ops import ( "makarna/pkg/backend/cpu" "makarna/pkg/tensor" ) // Repeat repeats tensor n times along dimension dim func Repeat(t *cpu.Tensor, dim, n int) *cpu.Tensor { shape := t.Shape() data := t.DataFloat32() newShape := make(tensor.Shape, len(shape)) copy(newShape, shape) newShape[dim] *= n // Simple case: repeat along first dimension if dim == 0 { newData := make([]float32, 0, newShape.NumElements()) for i := 0; i < n; i++ { newData = append(newData, data...) } return cpu.NewTensor(newShape, newData) } // Simple case: repeat along last dimension if dim == len(shape)-1 { newData := make([]float32, newShape.NumElements()) rowSize := shape[dim] newRowSize := newShape[dim] numRows := len(data) / rowSize for row := 0; row < numRows; row++ { srcStart := row * rowSize dstStart := row * newRowSize for rep := 0; rep < n; rep++ { copy(newData[dstStart+rep*rowSize:], data[srcStart:srcStart+rowSize]) } } return cpu.NewTensor(newShape, newData) } // General case: repeat along middle dimension // Calculate outer (before dim), inner (after dim) sizes outerSize := 1 for i := 0; i < dim; i++ { outerSize *= shape[i] } innerSize := 1 for i := dim + 1; i < len(shape); i++ { innerSize *= shape[i] } dimSize := shape[dim] sliceSize := dimSize * innerSize newData := make([]float32, newShape.NumElements()) dstIdx := 0 for outer := 0; outer < outerSize; outer++ { srcStart := outer * sliceSize srcSlice := data[srcStart : srcStart+sliceSize] for rep := 0; rep < n; rep++ { copy(newData[dstIdx:], srcSlice) dstIdx += sliceSize } } return cpu.NewTensor(newShape, newData) } // RepeatInterleave repeats each element n times along dimension func RepeatInterleave(t *cpu.Tensor, dim, n int) *cpu.Tensor { shape := t.Shape() data := t.DataFloat32() newShape := make(tensor.Shape, len(shape)) copy(newShape, shape) newShape[dim] *= n // Handle 2D case with dim=0 if len(shape) == 2 && dim == 0 { newData := make([]float32, newShape.NumElements()) rowSize := shape[1] for row := 0; row < shape[0]; row++ { srcRow := data[row*rowSize : (row+1)*rowSize] for rep := 0; rep < n; rep++ { dstStart := (row*n + rep) * rowSize copy(newData[dstStart:], srcRow) } } return cpu.NewTensor(newShape, newData) } return nil }