token_embedding_cuda.go

//go:build cuda

package compute

import (
	"fmt"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/backend/device"
	"makarna/pkg/tensor"
)
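
// HybridTokenEmbedding returns the embeddings for ids as a [len(ids), hiddenSize]
// Float32 activation. For single-token decode with the first layer on CUDA and a
// K-quant embedding table, it dequantizes just the selected row on the GPU;
// otherwise it falls back to a generic CPU gather.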
func HybridTokenEmbedding(ids []int, tokenEmb tensor.Tensor, hiddenSize int, dispatcher *device.DeviceDispatcher) (*Activation, error) {
	seqLen := len(ids)
	// Fast path: single-token decode with the first layer placed on a GPU and
	// quantized embedding weights. Dequantize only the selected row directly on
	// the GPU, avoiding a full CPU dequantization plus a host-to-device copy.
	if seqLen == 1 && dispatcher != nil {
		p0 := dispatcher.LayerPlacement(0).Normalize()
		if p0.Type == tensor.CUDA && cuda.Available() {
			gpu := p0.GPU
			if gpu < 0 {
				gpu = 0
			}
			if embCPU, ok := tokenEmb.(*cpu.Tensor); ok {
				dtype := embCPU.DType()
				// K-quants pack 256 values per block, so the row width must be a
				// multiple of 256 for per-row offsets to be block-aligned.
				blocksPerDim := hiddenSize / 256
				if blocksPerDim > 0 && hiddenSize%256 == 0 {
					// Upload the quantized embedding table once per GPU and reuse
					// it across decode steps via the weight cache.
					cache := GetWeightCache(gpu)
					wKey := fmt.Sprintf("tok_emb_%p", embCPU)
					gpuW, ok2 := cache.Get(wKey)
					var err error
					if !ok2 {
						gpuW, err = cache.Upload(wKey, embCPU)
						if err != nil {
							return nil, err
						}
					}
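					// The per-block byte counts below match the ggml-style K-quant
					// block layouts (256 values per block). For example, a Q4_K
					// block holds two f16 super-block scales (4 B), 12 B of packed
					// 6-bit sub-block scales, and 128 B of 4-bit quants: 144 B total.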
					var bytesPerBlock int
					switch dtype {
					case tensor.Q8_K:
						bytesPerBlock = 292
					case tensor.Q5_K:
						bytesPerBlock = 176
					case tensor.Q4_K:
						bytesPerBlock = 144
					case tensor.Q2_K:
						bytesPerBlock = 84
					case tensor.Q3_K:
						bytesPerBlock = 110
					case tensor.Q6_K:
						bytesPerBlock = 210
					default:
						bytesPerBlock = 0 // unsupported dtype: use the CPU path below
					}
					if bytesPerBlock > 0 {
						hiddenGPU, err2 := cuda.NewTensor(tensor.Shape{1, hiddenSize}, tensor.Float32, gpu)
						if err2 == nil {
							// Compute the byte offset of the selected row inside the
							// quantized table and dequantize just that row on device.
							id := ids[0]
							rowOff := uintptr(id * blocksPerDim * bytesPerBlock)
							rowPtr := unsafe.Pointer(uintptr(gpuW) + rowOff)
							outPtr := hiddenGPU.Data().(unsafe.Pointer)
							switch dtype {
							case tensor.Q8_K:
								err2 = cuda.DequantQ8K(rowPtr, outPtr, blocksPerDim, gpu)
							case tensor.Q5_K:
								err2 = cuda.DequantQ5K(rowPtr, outPtr, blocksPerDim, gpu)
							case tensor.Q4_K:
								err2 = cuda.DequantQ4K(rowPtr, outPtr, blocksPerDim, gpu)
							case tensor.Q2_K:
								err2 = cuda.DequantQ2K(rowPtr, outPtr, blocksPerDim, gpu)
							case tensor.Q3_K:
								err2 = cuda.DequantQ3K(rowPtr, outPtr, blocksPerDim, gpu)
							case tensor.Q6_K:
								err2 = cuda.DequantQ6K(rowPtr, outPtr, blocksPerDim, gpu)
							}
							if err2 == nil {
								return NewActivationFrom(hiddenGPU), nil
							}
							// Any dequant failure falls through to the CPU path below.
						}
					}
				}
			}
		}
	}
	// Generic path: materialize the embedding table on the CPU and gather the
	// requested rows there.
	embCPU, err := ToCPU(tokenEmb)
	if err != nil {
		return nil, err
	}
	hiddenCPU := cpu.NewTensor(tensor.Shape{seqLen, hiddenSize}, nil)
	if err := nn.Embedding(ids, embCPU, hiddenCPU); err != nil {
		return nil, err
	}
	return NewActivationFrom(hiddenCPU), nil
}
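
// Hypothetical call site (a sketch, not part of this file; nextID, model.TokEmb,
// cfg.HiddenSize, and disp are assumed names): a decode loop would look up one
// token per step, e.g.
//
//	act, err := HybridTokenEmbedding([]int{nextID}, model.TokEmb, cfg.HiddenSize, disp)
//	if err != nil {
//		return err
//	}
//
// When the fast path is taken, act already resides on the first layer's GPU, so
// the next layer can consume it without an extra host-to-device copy.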