package nn

import (
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/tensor"
)

// Embedding looks up token embeddings, gathering one row of weight per token
// id and writing the rows (dequantized to float32 as needed) into out.
//
//	ids:    token IDs, each must be in [0, vocab_size)
//	weight: [vocab_size, dim] embedding table; Float32, Float16, BFloat16,
//	        or one of the K-quant dtypes handled in the switch below
//	out:    [len(ids), dim] Float32 destination tensor
//
// Returns an error (never panics) on shape mismatch, an out-of-range id,
// or an unsupported weight dtype.
func Embedding(ids []int, weight, out *cpu.Tensor) error {
	inShape := weight.Shape()
	if len(inShape) != 2 {
		return fmt.Errorf("embedding: expected 2D weight, got %v", inShape)
	}
	vocabSize := inShape[0]
	dim := inShape[1]

	// Validate output shape before touching its data. Checking the rank
	// first avoids an index-out-of-range panic on a non-2D output tensor.
	outShape := out.Shape()
	if len(outShape) != 2 {
		return fmt.Errorf("embedding: expected 2D output, got %v", outShape)
	}
	if outShape[0] != len(ids) || outShape[1] != dim {
		return fmt.Errorf("embedding: output shape mismatch: expected [%d, %d], got %v",
			len(ids), dim, outShape)
	}
	oData := out.DataFloat32()

	// checkID reports an error for a token id outside [0, vocabSize).
	checkID := func(id int) error {
		if id < 0 || id >= vocabSize {
			return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
		}
		return nil
	}

	// K-quant super-block size: every quantized block dequantizes to 256 floats.
	const blockSize = 256

	// checkQuantDim verifies dim is a whole number of quant blocks.
	checkQuantDim := func(dtypeName string) error {
		if dim%blockSize != 0 {
			return fmt.Errorf("embedding: %s dim %d must be multiple of %d",
				dtypeName, dim, blockSize)
		}
		return nil
	}

	// convert16 gathers rows of 16-bit values, converting each element to
	// float32 via toF32 (shared by the Float16 and BFloat16 cases).
	convert16 := func(wData []uint16, toF32 func(uint16) float32) error {
		for i, id := range ids {
			if err := checkID(id); err != nil {
				return err
			}
			src := wData[id*dim : (id+1)*dim]
			dst := oData[i*dim : (i+1)*dim]
			for j, bits := range src {
				dst[j] = toF32(bits)
			}
		}
		return nil
	}

	switch weight.DType() {
	case tensor.Float32:
		wData := weight.DataFloat32()
		for i, id := range ids {
			if err := checkID(id); err != nil {
				return err
			}
			copy(oData[i*dim:(i+1)*dim], wData[id*dim:(id+1)*dim])
		}

	case tensor.Float16:
		return convert16(weight.DataUint16(), float16BitsToFloat32)

	case tensor.BFloat16:
		return convert16(weight.DataUint16(), bfloat16BitsToFloat32)

	// Each quantized case dequantizes blocks directly into the output
	// segment, avoiding an intermediate scratch buffer and copy.
	case tensor.Q4_K:
		if err := checkQuantDim("Q4_K"); err != nil {
			return err
		}
		wData := weight.DataQ4_K()
		blocksPerRow := dim / blockSize
		for i, id := range ids {
			if err := checkID(id); err != nil {
				return err
			}
			dst := oData[i*dim : (i+1)*dim]
			base := id * blocksPerRow
			for b := 0; b < blocksPerRow; b++ {
				tensor.DequantizeQ4_K(&wData[base+b], dst[b*blockSize:(b+1)*blockSize])
			}
		}

	case tensor.Q8_K:
		if err := checkQuantDim("Q8_K"); err != nil {
			return err
		}
		wData := weight.DataQ8_K()
		blocksPerRow := dim / blockSize
		for i, id := range ids {
			if err := checkID(id); err != nil {
				return err
			}
			dst := oData[i*dim : (i+1)*dim]
			base := id * blocksPerRow
			for b := 0; b < blocksPerRow; b++ {
				tensor.DequantizeQ8_K(&wData[base+b], dst[b*blockSize:(b+1)*blockSize])
			}
		}

	case tensor.Q3_K:
		if err := checkQuantDim("Q3_K"); err != nil {
			return err
		}
		wData := weight.DataQ3_K()
		blocksPerRow := dim / blockSize
		for i, id := range ids {
			if err := checkID(id); err != nil {
				return err
			}
			dst := oData[i*dim : (i+1)*dim]
			base := id * blocksPerRow
			for b := 0; b < blocksPerRow; b++ {
				tensor.DequantizeQ3_K(&wData[base+b], dst[b*blockSize:(b+1)*blockSize])
			}
		}

	case tensor.Q5_K:
		if err := checkQuantDim("Q5_K"); err != nil {
			return err
		}
		wData := weight.DataQ5_K()
		blocksPerRow := dim / blockSize
		for i, id := range ids {
			if err := checkID(id); err != nil {
				return err
			}
			dst := oData[i*dim : (i+1)*dim]
			base := id * blocksPerRow
			for b := 0; b < blocksPerRow; b++ {
				tensor.DequantizeQ5_K(&wData[base+b], dst[b*blockSize:(b+1)*blockSize])
			}
		}

	case tensor.Q6_K:
		if err := checkQuantDim("Q6_K"); err != nil {
			return err
		}
		wData := weight.DataQ6_K()
		blocksPerRow := dim / blockSize
		for i, id := range ids {
			if err := checkID(id); err != nil {
				return err
			}
			dst := oData[i*dim : (i+1)*dim]
			base := id * blocksPerRow
			for b := 0; b < blocksPerRow; b++ {
				tensor.DequantizeQ6_K(&wData[base+b], dst[b*blockSize:(b+1)*blockSize])
			}
		}

	case tensor.Q2_K:
		if err := checkQuantDim("Q2_K"); err != nil {
			return err
		}
		wData := weight.DataQ2_K()
		blocksPerRow := dim / blockSize
		for i, id := range ids {
			if err := checkID(id); err != nil {
				return err
			}
			dst := oData[i*dim : (i+1)*dim]
			base := id * blocksPerRow
			for b := 0; b < blocksPerRow; b++ {
				tensor.DequantizeQ2_K(&wData[base+b], dst[b*blockSize:(b+1)*blockSize])
			}
		}

	default:
		return fmt.Errorf("embedding: unsupported weight dtype %v", weight.DType())
	}
	return nil
}