input.go 925 B

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. package nn
  2. import (
  3. "unsafe"
  4. "makarna/pkg/tensor"
  5. )
  6. // ParseTokenIDs extracts integer token IDs from a float32 tensor
  7. // Input tensor shape: [seqLen]
  8. func ParseTokenIDs(input tensor.Tensor) []int {
  9. seqLen := input.Shape()[0]
  10. ptr := input.Data().(unsafe.Pointer)
  11. slice := unsafe.Slice((*float32)(ptr), seqLen)
  12. ids := make([]int, seqLen)
  13. for i, v := range slice {
  14. ids[i] = int(v)
  15. }
  16. return ids
  17. }
  18. // ParsePositions extracts position indices from tensor or generates default [0,1,2,...]
  19. // If positions is nil, generates sequential positions starting from 0
  20. func ParsePositions(positions tensor.Tensor, seqLen int) []int {
  21. if positions == nil {
  22. arr := make([]int, seqLen)
  23. for i := range arr {
  24. arr[i] = i
  25. }
  26. return arr
  27. }
  28. ptr := positions.Data().(unsafe.Pointer)
  29. slice := unsafe.Slice((*float32)(ptr), seqLen)
  30. arr := make([]int, seqLen)
  31. for i, v := range slice {
  32. arr[i] = int(v)
  33. }
  34. return arr
  35. }