//go:build cuda package compute import ( "fmt" "unsafe" "makarna/pkg/backend/cpu" "makarna/pkg/backend/cpu/nn" "makarna/pkg/backend/cuda" "makarna/pkg/backend/device" "makarna/pkg/tensor" ) func HybridTokenEmbedding(ids []int, tokenEmb tensor.Tensor, hiddenSize int, dispatcher *device.DeviceDispatcher) (*Activation, error) { seqLen := len(ids) // Fast-path: single-token decode, first layer on GPU, embedding weights are quantized. // We dequantize only the selected row directly on GPU to avoid CPU dequant + H2D copy. if seqLen == 1 && dispatcher != nil { p0 := dispatcher.LayerPlacement(0).Normalize() if p0.Type == tensor.CUDA && cuda.Available() { gpu := p0.GPU if gpu < 0 { gpu = 0 } if embCPU, ok := tokenEmb.(*cpu.Tensor); ok { dtype := embCPU.DType() blocksPerDim := hiddenSize / 256 if blocksPerDim > 0 && hiddenSize%256 == 0 { cache := GetWeightCache(gpu) wKey := fmt.Sprintf("tok_emb_%p", embCPU) gpuW, ok2 := cache.Get(wKey) var err error if !ok2 { gpuW, err = cache.Upload(wKey, embCPU) if err != nil { return nil, err } } var bytesPerBlock int switch dtype { case tensor.Q8_K: bytesPerBlock = 292 case tensor.Q5_K: bytesPerBlock = 176 case tensor.Q4_K: bytesPerBlock = 144 case tensor.Q2_K: bytesPerBlock = 84 case tensor.Q3_K: bytesPerBlock = 110 case tensor.Q6_K: bytesPerBlock = 210 default: bytesPerBlock = 0 } if bytesPerBlock > 0 { hiddenGPU, err2 := cuda.NewTensor(tensor.Shape{1, hiddenSize}, tensor.Float32, gpu) if err2 == nil { id := ids[0] rowOff := uintptr(id * blocksPerDim * bytesPerBlock) rowPtr := unsafe.Pointer(uintptr(gpuW) + rowOff) outPtr := hiddenGPU.Data().(unsafe.Pointer) switch dtype { case tensor.Q8_K: err2 = cuda.DequantQ8K(rowPtr, outPtr, blocksPerDim, gpu) case tensor.Q5_K: err2 = cuda.DequantQ5K(rowPtr, outPtr, blocksPerDim, gpu) case tensor.Q4_K: err2 = cuda.DequantQ4K(rowPtr, outPtr, blocksPerDim, gpu) case tensor.Q2_K: err2 = cuda.DequantQ2K(rowPtr, outPtr, blocksPerDim, gpu) case tensor.Q3_K: err2 = cuda.DequantQ3K(rowPtr, outPtr, blocksPerDim, gpu) case tensor.Q6_K: err2 = cuda.DequantQ6K(rowPtr, outPtr, blocksPerDim, gpu) } if err2 == nil { return NewActivationFrom(hiddenGPU), nil } } } } } } } // Generic CPU embedding. embCPU, err := ToCPU(tokenEmb) if err != nil { return nil, err } hiddenCPU := cpu.NewTensor(tensor.Shape{seqLen, hiddenSize}, nil) if err := nn.Embedding(ids, embCPU, hiddenCPU); err != nil { return nil, err } return NewActivationFrom(hiddenCPU), nil }