forward_device.go

// Package qwen3 implements a device-aware forward pass using Hybrid* operations.
// This enables efficient GPU/CPU offloading without duplicating the forward-pass logic.
package qwen3

import (
	"context"
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/compute"
	"makarna/pkg/kvcache"
	"makarna/pkg/model"
	"makarna/pkg/model/arch"
	"makarna/pkg/tensor"
)
// Forward performs a forward pass with automatic device placement.
// If a dispatcher is present in the context, GPU operations are used for GPU-resident layers.
func (m *Model) Forward(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCache model.KVCache) (tensor.Tensor, error) {
	cfg := m.config
	seqLen := input.Shape()[0]
	headDim := cfg.HeadDim
	if headDim == 0 {
		headDim = cfg.HiddenSize / cfg.NumHeads
	}
	rmsNormEps := float32(cfg.RMSNormEps)
	if rmsNormEps == 0 {
		rmsNormEps = 1e-6
	}

	// Parse inputs.
	ids := nn.ParseTokenIDs(input)
	posArr := nn.ParsePositions(positions, seqLen)

	dispatcher := compute.DispatcherFromContext(ctx)
	hidden, err := compute.HybridTokenEmbedding(ids, m.tokenEmb, cfg.HiddenSize, dispatcher)
	if err != nil {
		return nil, fmt.Errorf("embedding: %w", err)
	}

	scratchSet := compute.ScratchSetFromContext(ctx)
	baseScratch := compute.ScratchFromContext(ctx)
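
	// Note on the per-layer scratch selection below: scratchSet carries one scratch
	// space per GPU, while baseScratch is a single space that is only usable when its
	// GPU index matches the layer's placement.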
	// Get the cache; both Cache and PagedKVCache are supported via the interface.
	var cache kvcache.KVCacheInterface
	if kvCache != nil {
		cache, _ = kvCache.(kvcache.KVCacheInterface)
	}
	// Process transformer layers with device-aware operations.
	lastPlacement := tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
	for i, layer := range m.layers {
		// Get the target device for this layer.
		var targetPlacement tensor.DevicePlacement
		if dispatcher != nil {
			targetPlacement = dispatcher.LayerPlacement(i).Normalize()
		} else {
			targetPlacement = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
		}

		// Transfer the activation if we cross a device boundary.
		if targetPlacement != lastPlacement {
			if err := compute.EnsureOnDevice(hidden, targetPlacement); err != nil {
				return nil, fmt.Errorf("layer %d: device transfer failed: %w", i, err)
			}
		}
		lastPlacement = targetPlacement

		var layerScratch *compute.ScratchSpace
		if targetPlacement.Type == tensor.CUDA {
			if scratchSet != nil {
				layerScratch = scratchSet.Scratch(targetPlacement.GPU)
			} else if baseScratch != nil && baseScratch.GPU() == targetPlacement.GPU {
				layerScratch = baseScratch
			}
			if layerScratch != nil {
				layerScratch.Reset()
			}
		}

		// Create the compute context for this layer and run the transformer block
		// with device-aware operations.
		compCtx := compute.NewContext(dispatcher, i)
		compCtx.Scratch = layerScratch
		if err := m.transformerBlockDeviceAware(compCtx, hidden, layer, posArr, cache, cfg, headDim, rmsNormEps); err != nil {
			return nil, fmt.Errorf("layer %d: %w", i, err)
		}
	}
	// Commit the KV cache.
	if cache != nil {
		cache.Commit(seqLen)
	}

	// Final norm: build a context keyed to the last layer so that, if hidden is still
	// on GPU, execution stays there.
	finalCtx := compute.NewContext(dispatcher, len(m.layers)-1)
	if lastPlacement.Type == tensor.CUDA {
		if scratchSet != nil {
			finalCtx.Scratch = scratchSet.Scratch(lastPlacement.GPU)
		} else if baseScratch != nil && baseScratch.GPU() == lastPlacement.GPU {
			finalCtx.Scratch = baseScratch
		}
		if finalCtx.Scratch != nil {
			finalCtx.Scratch.Reset()
		}
	}

	// HybridRMSNorm keeps the activation on GPU if it is already there.
	if err := compute.HybridRMSNorm(finalCtx, hidden, m.norm, rmsNormEps); err != nil {
		return nil, fmt.Errorf("final norm: %w", err)
	}

	// LM head: the logits activation starts as a CPU tensor; HybridLinear handles device
	// placement and, if hidden is on GPU, executes there and promotes the logits to GPU.
	logitsAct := compute.NewActivationFrom(
		cpu.NewTensor(tensor.Shape{seqLen, cfg.VocabSize}, nil),
	)
	outputWeights := m.output
	if outputWeights == nil {
		outputWeights = m.tokenEmb
	}
	if err := compute.HybridLinear(finalCtx, hidden, outputWeights, logitsAct); err != nil {
		return nil, fmt.Errorf("lm head: %w", err)
	}

	// Return the logits on the device they were computed on.
	return logitsAct.Tensor(), nil
}
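
// Illustrative call site for Forward (a sketch, not part of this package). The context
// helpers named here (compute.WithDispatcher, compute.WithScratchSet) are assumptions;
// substitute whatever setters the compute package actually exposes alongside
// DispatcherFromContext and ScratchSetFromContext.
//
//	ctx := compute.WithDispatcher(context.Background(), dispatcher)
//	ctx = compute.WithScratchSet(ctx, scratchSet)
//	logits, err := mdl.Forward(ctx, tokens, positions, cache) // logits: {seqLen, vocabSize}
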
// ForwardBatch performs a device-aware forward pass over a batch of independent
// sequences, one token per sequence: row i of input/positions is the next token for
// kvCaches[i], and each cache is advanced by one position.
func (m *Model) ForwardBatch(ctx context.Context, input tensor.Tensor, positions tensor.Tensor, kvCaches []model.KVCache) (tensor.Tensor, error) {
	cfg := m.config
	seqLen := input.Shape()[0]
	if len(kvCaches) != seqLen {
		return nil, fmt.Errorf("kvCaches len %d != input len %d", len(kvCaches), seqLen)
	}
	headDim := cfg.HeadDim
	if headDim == 0 {
		headDim = cfg.HiddenSize / cfg.NumHeads
	}
	rmsNormEps := float32(cfg.RMSNormEps)
	if rmsNormEps == 0 {
		rmsNormEps = 1e-6
	}

	ids := nn.ParseTokenIDs(input)
	posArr := nn.ParsePositions(positions, seqLen)

	dispatcher := compute.DispatcherFromContext(ctx)
	hidden, err := compute.HybridTokenEmbedding(ids, m.tokenEmb, cfg.HiddenSize, dispatcher)
	if err != nil {
		return nil, fmt.Errorf("embedding: %w", err)
	}
	scratchSet := compute.ScratchSetFromContext(ctx)
	baseScratch := compute.ScratchFromContext(ctx)

	caches := make([]kvcache.KVCacheInterface, len(kvCaches))
	for i := range kvCaches {
		if kvCaches[i] == nil {
			caches[i] = nil
			continue
		}
		c, ok := kvCaches[i].(kvcache.KVCacheInterface)
		if !ok {
			return nil, fmt.Errorf("kvCache[%d] does not implement KVCacheInterface: %T", i, kvCaches[i])
		}
		caches[i] = c
	}
	lastPlacement := tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
	for i, layer := range m.layers {
		var targetPlacement tensor.DevicePlacement
		if dispatcher != nil {
			targetPlacement = dispatcher.LayerPlacement(i).Normalize()
		} else {
			targetPlacement = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
		}
		if targetPlacement != lastPlacement {
			if err := compute.EnsureOnDevice(hidden, targetPlacement); err != nil {
				return nil, fmt.Errorf("layer %d: device transfer failed: %w", i, err)
			}
		}
		lastPlacement = targetPlacement

		var layerScratch *compute.ScratchSpace
		if targetPlacement.Type == tensor.CUDA {
			if scratchSet != nil {
				layerScratch = scratchSet.Scratch(targetPlacement.GPU)
			} else if baseScratch != nil && baseScratch.GPU() == targetPlacement.GPU {
				layerScratch = baseScratch
			}
			if layerScratch != nil {
				layerScratch.Reset()
			}
		}

		compCtx := compute.NewContext(dispatcher, i)
		compCtx.Scratch = layerScratch
		blockCfg := arch.HybridDecoderConfig{
			HiddenSize:   cfg.HiddenSize,
			NumHeads:     cfg.NumHeads,
			NumKVHeads:   cfg.NumKVHeads,
			Intermediate: cfg.Intermediate,
			HeadDim:      headDim,
			RopeTheta:    float32(cfg.RopeTheta),
		}
		blockLayer := arch.HybridDecoderLayerWeights{
			Idx:      layer.idx,
			AttnNorm: layer.attnNorm,
			Wq:       layer.wq,
			Wk:       layer.wk,
			Wv:       layer.wv,
			Wo:       layer.wo,
			QNorm:    layer.qNorm,
			KNorm:    layer.kNorm,
			MlpNorm:  layer.mlpNorm,
			WGate:    layer.wGate,
			WUp:      layer.wUp,
			WDown:    layer.wDown,
		}
		if err := arch.HybridDecoderBlockBatch(compCtx, hidden, &blockLayer, posArr, caches, blockCfg, rmsNormEps); err != nil {
			return nil, fmt.Errorf("layer %d: %w", i, err)
		}
	}
	// Each per-sequence cache advances by exactly one token.
	for i := range caches {
		if caches[i] != nil {
			caches[i].Commit(1)
		}
	}

	finalCtx := compute.NewContext(dispatcher, len(m.layers)-1)
	if lastPlacement.Type == tensor.CUDA {
		if scratchSet != nil {
			finalCtx.Scratch = scratchSet.Scratch(lastPlacement.GPU)
		} else if baseScratch != nil && baseScratch.GPU() == lastPlacement.GPU {
			finalCtx.Scratch = baseScratch
		}
		if finalCtx.Scratch != nil {
			finalCtx.Scratch.Reset()
		}
	}
	if err := compute.HybridRMSNorm(finalCtx, hidden, m.norm, rmsNormEps); err != nil {
		return nil, fmt.Errorf("final norm: %w", err)
	}

	logitsAct := compute.NewActivationFrom(
		cpu.NewTensor(tensor.Shape{seqLen, cfg.VocabSize}, nil),
	)
	outputWeights := m.output
	if outputWeights == nil {
		outputWeights = m.tokenEmb
	}
	if err := compute.HybridLinear(finalCtx, hidden, outputWeights, logitsAct); err != nil {
		return nil, fmt.Errorf("lm head: %w", err)
	}
	return logitsAct.Tensor(), nil
}
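
// Illustrative batched decode step using ForwardBatch (a sketch; packTokens and
// packPositions are hypothetical helpers that stack one pending token and position per
// active sequence, in the same order as kvCaches).
//
//	logits, err := mdl.ForwardBatch(ctx, packTokens(seqs), packPositions(seqs), kvCaches)
//	// logits has shape {len(kvCaches), vocabSize}; row i belongs to sequence i.
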
// transformerBlockDeviceAware processes a single transformer layer using device-aware operations.
func (m *Model) transformerBlockDeviceAware(ctx *compute.Context, hidden *compute.Activation, layer *Layer, positions []int, cache kvcache.KVCacheInterface, cfg *model.Config, headDim int, eps float32) error {
	blockCfg := arch.HybridDecoderConfig{
		HiddenSize:   cfg.HiddenSize,
		NumHeads:     cfg.NumHeads,
		NumKVHeads:   cfg.NumKVHeads,
		Intermediate: cfg.Intermediate,
		HeadDim:      headDim,
		RopeTheta:    float32(cfg.RopeTheta),
	}
	blockLayer := arch.HybridDecoderLayerWeights{
		Idx:      layer.idx,
		AttnNorm: layer.attnNorm,
		Wq:       layer.wq,
		Wk:       layer.wk,
		Wv:       layer.wv,
		Wo:       layer.wo,
		QNorm:    layer.qNorm,
		KNorm:    layer.kNorm,
		MlpNorm:  layer.mlpNorm,
		WGate:    layer.wGate,
		WUp:      layer.wUp,
		WDown:    layer.wDown,
	}
	return arch.HybridDecoderBlock(ctx, hidden, &blockLayer, positions, cache, blockCfg, eps)
}
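
// Note: ForwardBatch and transformerBlockDeviceAware build identical
// arch.HybridDecoderLayerWeights literals. A small helper along these lines could remove
// the duplication; this is an illustrative sketch and is not referenced above.
func hybridLayerWeights(layer *Layer) arch.HybridDecoderLayerWeights {
	return arch.HybridDecoderLayerWeights{
		Idx:      layer.idx,
		AttnNorm: layer.attnNorm,
		Wq:       layer.wq,
		Wk:       layer.wk,
		Wv:       layer.wv,
		Wo:       layer.wo,
		QNorm:    layer.qNorm,
		KNorm:    layer.kNorm,
		MlpNorm:  layer.mlpNorm,
		WGate:    layer.wGate,
		WUp:      layer.wUp,
		WDown:    layer.wDown,
	}
}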