| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229 |
- package nn
- import (
- "fmt"
- "makarna/pkg/backend/cpu"
- "makarna/pkg/tensor"
- )
- // Embedding looks up token embeddings
- // ids: token IDs
- // weight: [vocab_size, dim]
- // out: [seq_len, dim]
- func Embedding(ids []int, weight, out *cpu.Tensor) error {
- inShape := weight.Shape()
- if len(inShape) != 2 {
- return fmt.Errorf("embedding: expected 2D weight, got %v", inShape)
- }
- vocabSize := inShape[0]
- dim := inShape[1]
- oData := out.DataFloat32()
- // Validate output shape
- outShape := out.Shape()
- if outShape[0] != len(ids) || outShape[1] != dim {
- return fmt.Errorf("embedding: output shape mismatch: expected [%d, %d], got %v", len(ids), dim, outShape)
- }
- switch weight.DType() {
- case tensor.Float32:
- wData := weight.DataFloat32()
- for i, id := range ids {
- if id < 0 || id >= vocabSize {
- return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
- }
- src := wData[id*dim : (id+1)*dim]
- dst := oData[i*dim : (i+1)*dim]
- copy(dst, src)
- }
- case tensor.Float16:
- wData := weight.DataUint16()
- for i, id := range ids {
- if id < 0 || id >= vocabSize {
- return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
- }
- src := wData[id*dim : (id+1)*dim]
- dst := oData[i*dim : (i+1)*dim]
- for j := 0; j < dim; j++ {
- dst[j] = float16BitsToFloat32(src[j])
- }
- }
- case tensor.BFloat16:
- wData := weight.DataUint16()
- for i, id := range ids {
- if id < 0 || id >= vocabSize {
- return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
- }
- src := wData[id*dim : (id+1)*dim]
- dst := oData[i*dim : (i+1)*dim]
- for j := 0; j < dim; j++ {
- dst[j] = bfloat16BitsToFloat32(src[j])
- }
- }
- case tensor.Q4_K:
- const blockSize = 256
- if dim%blockSize != 0 {
- return fmt.Errorf("embedding: Q4_K dim %d must be multiple of %d", dim, blockSize)
- }
- wData := weight.DataQ4_K()
- blocksPerDim := dim / blockSize
- var deqBuf [256]float32
- for i, id := range ids {
- if id < 0 || id >= vocabSize {
- return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
- }
- dst := oData[i*dim : (i+1)*dim]
- blockStart := id * blocksPerDim
- for b := 0; b < blocksPerDim; b++ {
- block := &wData[blockStart+b]
- tensor.DequantizeQ4_K(block, deqBuf[:])
- copy(dst[b*blockSize:], deqBuf[:])
- }
- }
- case tensor.Q8_K:
- const blockSize = 256
- if dim%blockSize != 0 {
- return fmt.Errorf("embedding: Q8_K dim %d must be multiple of %d", dim, blockSize)
- }
- wData := weight.DataQ8_K()
- blocksPerDim := dim / blockSize
- var deqBuf [256]float32
- for i, id := range ids {
- if id < 0 || id >= vocabSize {
- return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
- }
- dst := oData[i*dim : (i+1)*dim]
- blockStart := id * blocksPerDim
- for b := 0; b < blocksPerDim; b++ {
- block := &wData[blockStart+b]
- tensor.DequantizeQ8_K(block, deqBuf[:])
- copy(dst[b*blockSize:], deqBuf[:])
- }
- }
- case tensor.Q3_K:
- const blockSize = 256
- if dim%blockSize != 0 {
- return fmt.Errorf("embedding: Q3_K dim %d must be multiple of %d", dim, blockSize)
- }
- wData := weight.DataQ3_K()
- blocksPerDim := dim / blockSize
- var deqBuf [256]float32
- for i, id := range ids {
- if id < 0 || id >= vocabSize {
- return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
- }
- dst := oData[i*dim : (i+1)*dim]
- blockStart := id * blocksPerDim
- for b := 0; b < blocksPerDim; b++ {
- block := &wData[blockStart+b]
- tensor.DequantizeQ3_K(block, deqBuf[:])
- copy(dst[b*blockSize:], deqBuf[:])
- }
- }
- case tensor.Q5_K:
- const blockSize = 256
- if dim%blockSize != 0 {
- return fmt.Errorf("embedding: Q5_K dim %d must be multiple of %d", dim, blockSize)
- }
- wData := weight.DataQ5_K()
- blocksPerDim := dim / blockSize
- var deqBuf [256]float32
- for i, id := range ids {
- if id < 0 || id >= vocabSize {
- return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
- }
- dst := oData[i*dim : (i+1)*dim]
- blockStart := id * blocksPerDim
- for b := 0; b < blocksPerDim; b++ {
- block := &wData[blockStart+b]
- tensor.DequantizeQ5_K(block, deqBuf[:])
- copy(dst[b*blockSize:], deqBuf[:])
- }
- }
- case tensor.Q6_K:
- const blockSize = 256
- if dim%blockSize != 0 {
- return fmt.Errorf("embedding: Q6_K dim %d must be multiple of %d", dim, blockSize)
- }
- wData := weight.DataQ6_K()
- blocksPerDim := dim / blockSize
- for i, id := range ids {
- if id < 0 || id >= vocabSize {
- return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
- }
- dst := oData[i*dim : (i+1)*dim]
- blockStart := id * blocksPerDim
- for b := 0; b < blocksPerDim; b++ {
- block := &wData[blockStart+b]
- seg := dst[b*blockSize : (b+1)*blockSize]
- tensor.DequantizeQ6_K(block, seg)
- }
- }
- case tensor.Q2_K:
- const blockSize = 256
- if dim%blockSize != 0 {
- return fmt.Errorf("embedding: Q2_K dim %d must be multiple of %d", dim, blockSize)
- }
- wData := weight.DataQ2_K()
- blocksPerDim := dim / blockSize
- var deqBuf [256]float32
- for i, id := range ids {
- if id < 0 || id >= vocabSize {
- return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
- }
- dst := oData[i*dim : (i+1)*dim]
- blockStart := id * blocksPerDim
- for b := 0; b < blocksPerDim; b++ {
- block := &wData[blockStart+b]
- tensor.DequantizeQ2_K(block, deqBuf[:])
- copy(dst[b*blockSize:], deqBuf[:])
- }
- }
- default:
- return fmt.Errorf("embedding: unsupported weight dtype %v", weight.DType())
- }
- return nil
- }
|