| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- //go:build cuda
- package compute
- import (
- "fmt"
- "unsafe"
- "makarna/pkg/backend/cpu"
- "makarna/pkg/backend/cpu/nn"
- "makarna/pkg/backend/cuda"
- "makarna/pkg/backend/device"
- "makarna/pkg/tensor"
- )
- func HybridTokenEmbedding(ids []int, tokenEmb tensor.Tensor, hiddenSize int, dispatcher *device.DeviceDispatcher) (*Activation, error) {
- seqLen := len(ids)
- // Fast-path: single-token decode, first layer on GPU, embedding weights are quantized.
- // We dequantize only the selected row directly on GPU to avoid CPU dequant + H2D copy.
- if seqLen == 1 && dispatcher != nil {
- p0 := dispatcher.LayerPlacement(0).Normalize()
- if p0.Type == tensor.CUDA && cuda.Available() {
- gpu := p0.GPU
- if gpu < 0 {
- gpu = 0
- }
- if embCPU, ok := tokenEmb.(*cpu.Tensor); ok {
- dtype := embCPU.DType()
- blocksPerDim := hiddenSize / 256
- if blocksPerDim > 0 && hiddenSize%256 == 0 {
- cache := GetWeightCache(gpu)
- wKey := fmt.Sprintf("tok_emb_%p", embCPU)
- gpuW, ok2 := cache.Get(wKey)
- var err error
- if !ok2 {
- gpuW, err = cache.Upload(wKey, embCPU)
- if err != nil {
- return nil, err
- }
- }
- var bytesPerBlock int
- switch dtype {
- case tensor.Q8_K:
- bytesPerBlock = 292
- case tensor.Q5_K:
- bytesPerBlock = 176
- case tensor.Q4_K:
- bytesPerBlock = 144
- case tensor.Q2_K:
- bytesPerBlock = 84
- case tensor.Q3_K:
- bytesPerBlock = 110
- case tensor.Q6_K:
- bytesPerBlock = 210
- default:
- bytesPerBlock = 0
- }
- if bytesPerBlock > 0 {
- hiddenGPU, err2 := cuda.NewTensor(tensor.Shape{1, hiddenSize}, tensor.Float32, gpu)
- if err2 == nil {
- id := ids[0]
- rowOff := uintptr(id * blocksPerDim * bytesPerBlock)
- rowPtr := unsafe.Pointer(uintptr(gpuW) + rowOff)
- outPtr := hiddenGPU.Data().(unsafe.Pointer)
- switch dtype {
- case tensor.Q8_K:
- err2 = cuda.DequantQ8K(rowPtr, outPtr, blocksPerDim, gpu)
- case tensor.Q5_K:
- err2 = cuda.DequantQ5K(rowPtr, outPtr, blocksPerDim, gpu)
- case tensor.Q4_K:
- err2 = cuda.DequantQ4K(rowPtr, outPtr, blocksPerDim, gpu)
- case tensor.Q2_K:
- err2 = cuda.DequantQ2K(rowPtr, outPtr, blocksPerDim, gpu)
- case tensor.Q3_K:
- err2 = cuda.DequantQ3K(rowPtr, outPtr, blocksPerDim, gpu)
- case tensor.Q6_K:
- err2 = cuda.DequantQ6K(rowPtr, outPtr, blocksPerDim, gpu)
- }
- if err2 == nil {
- return NewActivationFrom(hiddenGPU), nil
- }
- }
- }
- }
- }
- }
- }
- // Generic CPU embedding.
- embCPU, err := ToCPU(tokenEmb)
- if err != nil {
- return nil, err
- }
- hiddenCPU := cpu.NewTensor(tensor.Shape{seqLen, hiddenSize}, nil)
- if err := nn.Embedding(ids, embCPU, hiddenCPU); err != nil {
- return nil, err
- }
- return NewActivationFrom(hiddenCPU), nil
- }
|