// slice.go — slice and concatenation operations for CPU tensors.
package ops

import (
	"makarna/pkg/backend/cpu"
	"makarna/pkg/tensor"
)
  6. // Slice extracts a portion of tensor along specified dimension
  7. // dim: dimension to slice
  8. // start, end: range [start, end)
  9. func Slice(t *cpu.Tensor, dim, start, end int) *cpu.Tensor {
  10. shape := t.Shape()
  11. data := t.DataFloat32()
  12. newShape := make(tensor.Shape, len(shape))
  13. copy(newShape, shape)
  14. newShape[dim] = end - start
  15. // Calculate strides
  16. strides := make([]int, len(shape))
  17. strides[len(shape)-1] = 1
  18. for i := len(shape) - 2; i >= 0; i-- {
  19. strides[i] = strides[i+1] * shape[i+1]
  20. }
  21. newData := make([]float32, newShape.NumElements())
  22. // Iterate and copy
  23. newIdx := 0
  24. indices := make([]int, len(shape))
  25. for i := 0; i < len(data); i++ {
  26. // Check if this index is within slice range for target dim
  27. if indices[dim] >= start && indices[dim] < end {
  28. newData[newIdx] = data[i]
  29. newIdx++
  30. }
  31. // Increment indices
  32. for d := len(indices) - 1; d >= 0; d-- {
  33. indices[d]++
  34. if indices[d] < shape[d] {
  35. break
  36. }
  37. indices[d] = 0
  38. }
  39. }
  40. return cpu.NewTensor(newShape, newData)
  41. }
  42. // Concat concatenates tensors along specified dimension
  43. func Concat(tensors []*cpu.Tensor, dim int) *cpu.Tensor {
  44. if len(tensors) == 0 {
  45. return nil
  46. }
  47. if len(tensors) == 1 {
  48. return tensors[0]
  49. }
  50. // Calculate new shape
  51. refShape := tensors[0].Shape()
  52. newShape := make(tensor.Shape, len(refShape))
  53. copy(newShape, refShape)
  54. totalDim := 0
  55. for _, t := range tensors {
  56. totalDim += t.Shape()[dim]
  57. }
  58. newShape[dim] = totalDim
  59. // Simple case: concat along last dimension
  60. if dim == len(refShape)-1 {
  61. newData := make([]float32, 0, newShape.NumElements())
  62. // For each row, append all tensors' data
  63. numRows := refShape.NumElements() / refShape[dim]
  64. for row := 0; row < numRows; row++ {
  65. for _, t := range tensors {
  66. tData := t.DataFloat32()
  67. rowSize := t.Shape()[dim]
  68. start := row * rowSize
  69. newData = append(newData, tData[start:start+rowSize]...)
  70. }
  71. }
  72. return cpu.NewTensor(newShape, newData)
  73. }
  74. // General case: just concatenate flat data (works for dim=0)
  75. if dim == 0 {
  76. newData := make([]float32, 0, newShape.NumElements())
  77. for _, t := range tensors {
  78. newData = append(newData, t.DataFloat32()...)
  79. }
  80. return cpu.NewTensor(newShape, newData)
  81. }
  82. // General case: concat along middle dimension
  83. // Calculate outer (before dim), inner (after dim) sizes
  84. outerSize := 1
  85. for i := 0; i < dim; i++ {
  86. outerSize *= refShape[i]
  87. }
  88. innerSize := 1
  89. for i := dim + 1; i < len(refShape); i++ {
  90. innerSize *= refShape[i]
  91. }
  92. newData := make([]float32, 0, newShape.NumElements())
  93. // For each outer index, copy all tensors' slices
  94. for outer := 0; outer < outerSize; outer++ {
  95. for _, t := range tensors {
  96. tData := t.DataFloat32()
  97. dimSize := t.Shape()[dim]
  98. sliceSize := dimSize * innerSize
  99. start := outer * sliceSize
  100. newData = append(newData, tData[start:start+sliceSize]...)
  101. }
  102. }
  103. return cpu.NewTensor(newShape, newData)
  104. }