| 123456789101112131415161718192021222324252627282930313233343536373839404142 |
- package nn
- import (
- "unsafe"
- "makarna/pkg/tensor"
- )
- // ParseTokenIDs extracts integer token IDs from a float32 tensor
- // Input tensor shape: [seqLen]
- func ParseTokenIDs(input tensor.Tensor) []int {
- seqLen := input.Shape()[0]
- ptr := input.Data().(unsafe.Pointer)
- slice := unsafe.Slice((*float32)(ptr), seqLen)
- ids := make([]int, seqLen)
- for i, v := range slice {
- ids[i] = int(v)
- }
- return ids
- }
- // ParsePositions extracts position indices from tensor or generates default [0,1,2,...]
- // If positions is nil, generates sequential positions starting from 0
- func ParsePositions(positions tensor.Tensor, seqLen int) []int {
- if positions == nil {
- arr := make([]int, seqLen)
- for i := range arr {
- arr[i] = i
- }
- return arr
- }
- ptr := positions.Data().(unsafe.Pointer)
- slice := unsafe.Slice((*float32)(ptr), seqLen)
- arr := make([]int, seqLen)
- for i, v := range slice {
- arr[i] = int(v)
- }
- return arr
- }
|