batcher.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. package engine
  2. import (
  3. "context"
  4. "fmt"
  5. "sort"
  6. "sync"
  7. "time"
  8. "unsafe"
  9. "makarna/pkg/backend/cpu"
  10. "makarna/pkg/backend/cuda"
  11. "makarna/pkg/backend/device"
  12. "makarna/pkg/compute"
  13. "makarna/pkg/kvcache"
  14. "makarna/pkg/model"
  15. "makarna/pkg/sample"
  16. "makarna/pkg/tensor"
  17. )
// DecodeEvent is one item delivered on a decode event channel: a sampled
// token, a terminal Done marker (optionally carrying Err), or both.
type DecodeEvent struct {
	// Token is the newly sampled token id (meaningful when Done is false).
	Token int
	// Done marks the final event for the sequence; the channel is closed after it.
	Done bool
	// Err is the failure or cancellation reason when the sequence ended abnormally.
	Err error
}
// DecodeSequence describes one in-flight generation request managed by the
// Batcher's decode loop.
type DecodeSequence struct {
	RequestID string                   // unique id used to register and stop the sequence
	Ctx       context.Context          // cancellation context; defaulted to context.Background() if nil
	Cache     kvcache.KVCacheInterface // per-sequence KV cache; SeqLen() provides the decode position
	// History includes prompt + generated tokens so far.
	History        []int
	NextInputToken int             // token fed into the next forward step
	Remaining      int             // tokens still to generate before the sequence is done
	EosID          int             // end-of-sequence token id; sampling it terminates the sequence
	Sampler        *sample.Sampler // sampling configuration/strategy for this sequence
}
// Batcher multiplexes many DecodeSequences onto batched forward passes.
// All sequence state is owned by a single loop goroutine; callers interact
// only through commands sent on cmdCh.
type Batcher struct {
	eng       *Engine   // engine that runs the batched forward passes
	cmdCh     chan any  // control commands (registerCmd / stopCmd)
	onceStart sync.Once // guards one-time start of the loop goroutine
}
// registerCmd asks the loop to start decoding seq. Registration outcome is
// reported once on resp; generated tokens are delivered on event.
type registerCmd struct {
	seq   *DecodeSequence
	event chan DecodeEvent
	resp  chan error
}
// stopCmd asks the loop to cancel the sequence registered under reqID.
type stopCmd struct {
	reqID string
}
// seqState pairs a registered sequence with its event channel inside the loop.
type seqState struct {
	seq   *DecodeSequence
	event chan DecodeEvent
}
  51. func NewBatcher(eng *Engine) *Batcher {
  52. return &Batcher{
  53. eng: eng,
  54. cmdCh: make(chan any, 1024),
  55. }
  56. }
  57. func (b *Batcher) Start() {
  58. b.onceStart.Do(func() {
  59. go b.loop()
  60. })
  61. }
  62. func (b *Batcher) RegisterDecode(seq *DecodeSequence) (<-chan DecodeEvent, error) {
  63. if seq == nil {
  64. return nil, fmt.Errorf("nil sequence")
  65. }
  66. if seq.Cache == nil {
  67. return nil, fmt.Errorf("nil cache")
  68. }
  69. if seq.Sampler == nil {
  70. return nil, fmt.Errorf("nil sampler")
  71. }
  72. if seq.Ctx == nil {
  73. seq.Ctx = context.Background()
  74. }
  75. if seq.Remaining <= 0 {
  76. ch := make(chan DecodeEvent)
  77. close(ch)
  78. return ch, nil
  79. }
  80. b.Start()
  81. event := make(chan DecodeEvent, 16)
  82. resp := make(chan error, 1)
  83. b.cmdCh <- registerCmd{seq: seq, event: event, resp: resp}
  84. return event, <-resp
  85. }
  86. func (b *Batcher) Stop(reqID string) {
  87. if reqID == "" {
  88. return
  89. }
  90. b.Start()
  91. b.cmdCh <- stopCmd{reqID: reqID}
  92. }
  93. func (b *Batcher) loop() {
  94. seqs := make(map[string]*seqState)
  95. // Scratch set reused for all batch steps (supports multi-GPU).
  96. var scratchSet *compute.ScratchSet
  97. var baseScratch *compute.ScratchSpace
  98. if b.eng != nil && b.eng.Dispatcher() != nil && cuda.Available() {
  99. cfg := b.eng.Model().Config()
  100. gpus := collectDispatcherGPUs(b.eng.Dispatcher(), cfg.NumLayers)
  101. if len(gpus) > 0 {
  102. if ss, err := compute.NewScratchSet(gpus, compute.DefaultScratchBytes); err == nil {
  103. scratchSet = ss
  104. defer scratchSet.Free()
  105. baseScratch = scratchSet.Scratch(gpus[0])
  106. }
  107. }
  108. }
  109. for {
  110. // Block when idle.
  111. if len(seqs) == 0 {
  112. cmd := <-b.cmdCh
  113. switch c := cmd.(type) {
  114. case registerCmd:
  115. if c.seq.RequestID == "" {
  116. c.resp <- fmt.Errorf("missing RequestID")
  117. close(c.event)
  118. continue
  119. }
  120. seqs[c.seq.RequestID] = &seqState{seq: c.seq, event: c.event}
  121. c.resp <- nil
  122. case stopCmd:
  123. if st, ok := seqs[c.reqID]; ok {
  124. st.event <- DecodeEvent{Done: true, Err: context.Canceled}
  125. close(st.event)
  126. delete(seqs, c.reqID)
  127. }
  128. }
  129. continue
  130. }
  131. // Drain control commands without blocking.
  132. for {
  133. select {
  134. case cmd := <-b.cmdCh:
  135. switch c := cmd.(type) {
  136. case registerCmd:
  137. if c.seq.RequestID == "" {
  138. c.resp <- fmt.Errorf("missing RequestID")
  139. close(c.event)
  140. continue
  141. }
  142. seqs[c.seq.RequestID] = &seqState{seq: c.seq, event: c.event}
  143. c.resp <- nil
  144. case stopCmd:
  145. if st, ok := seqs[c.reqID]; ok {
  146. st.event <- DecodeEvent{Done: true, Err: context.Canceled}
  147. close(st.event)
  148. delete(seqs, c.reqID)
  149. }
  150. }
  151. default:
  152. goto drained
  153. }
  154. }
  155. drained:
  156. // Collect ready sequences
  157. batch := make([]*seqState, 0, len(seqs))
  158. for id, st := range seqs {
  159. if st == nil || st.seq == nil {
  160. delete(seqs, id)
  161. continue
  162. }
  163. seq := st.seq
  164. if seq.Ctx != nil {
  165. select {
  166. case <-seq.Ctx.Done():
  167. st.event <- DecodeEvent{Done: true, Err: seq.Ctx.Err()}
  168. close(st.event)
  169. delete(seqs, id)
  170. continue
  171. default:
  172. }
  173. }
  174. if seq.Remaining <= 0 {
  175. st.event <- DecodeEvent{Done: true}
  176. close(st.event)
  177. delete(seqs, id)
  178. continue
  179. }
  180. if seq.NextInputToken == seq.EosID {
  181. st.event <- DecodeEvent{Done: true}
  182. close(st.event)
  183. delete(seqs, id)
  184. continue
  185. }
  186. batch = append(batch, st)
  187. }
  188. if len(batch) == 0 {
  189. time.Sleep(50 * time.Microsecond)
  190. continue
  191. }
  192. // Build input and positions
  193. input := cpu.NewTensor(tensor.Shape{len(batch)}, nil)
  194. pos := cpu.NewTensor(tensor.Shape{len(batch)}, nil)
  195. kvCaches := make([]model.KVCache, len(batch))
  196. for i, st := range batch {
  197. seq := st.seq
  198. input.DataFloat32()[i] = float32(seq.NextInputToken)
  199. pos.DataFloat32()[i] = float32(seq.Cache.SeqLen())
  200. kvCaches[i] = seq.Cache
  201. }
  202. ctx := context.Background()
  203. if scratchSet != nil {
  204. scratchSet.Reset()
  205. ctx = compute.WithScratchSet(ctx, scratchSet)
  206. }
  207. if baseScratch != nil {
  208. baseScratch.Reset()
  209. ctx = compute.WithScratch(ctx, baseScratch)
  210. }
  211. logits, err := b.eng.ForwardBatch(ctx, input, pos, kvCaches)
  212. if err != nil {
  213. for _, st := range batch {
  214. st.event <- DecodeEvent{Done: true, Err: err}
  215. close(st.event)
  216. delete(seqs, st.seq.RequestID)
  217. }
  218. continue
  219. }
  220. vocab := logits.Shape()[1]
  221. for i, st := range batch {
  222. seq := st.seq
  223. recent := seq.History
  224. if len(recent) > 64 {
  225. recent = recent[len(recent)-64:]
  226. }
  227. next, rowErr := sampleNextTokenFromLogits(logits, i, vocab, seq.Sampler, recent)
  228. if rowErr != nil {
  229. st.event <- DecodeEvent{Done: true, Err: rowErr}
  230. close(st.event)
  231. delete(seqs, seq.RequestID)
  232. continue
  233. }
  234. seq.History = append(seq.History, next)
  235. if pc, ok := seq.Cache.(*kvcache.PagedKVCache); ok {
  236. pc.AppendToken(next)
  237. }
  238. seq.NextInputToken = next
  239. seq.Remaining--
  240. st.event <- DecodeEvent{Token: next}
  241. if next == seq.EosID || seq.Remaining <= 0 {
  242. st.event <- DecodeEvent{Done: true}
  243. close(st.event)
  244. delete(seqs, seq.RequestID)
  245. }
  246. }
  247. }
  248. }
  249. func collectDispatcherGPUs(d *device.DeviceDispatcher, numLayers int) []int {
  250. if d == nil || numLayers <= 0 {
  251. return nil
  252. }
  253. seen := make(map[int]struct{})
  254. out := make([]int, 0, 4)
  255. for i := 0; i < numLayers; i++ {
  256. p := d.LayerPlacement(i).Normalize()
  257. if p.Type != tensor.CUDA || p.GPU < 0 {
  258. continue
  259. }
  260. if _, ok := seen[p.GPU]; ok {
  261. continue
  262. }
  263. seen[p.GPU] = struct{}{}
  264. out = append(out, p.GPU)
  265. }
  266. sort.Ints(out)
  267. return out
  268. }
  269. func logitsRowCPU(logits tensor.Tensor, row int, vocab int) ([]float32, error) {
  270. if logits == nil {
  271. return nil, fmt.Errorf("nil logits")
  272. }
  273. shape := logits.Shape()
  274. if len(shape) != 2 {
  275. return nil, fmt.Errorf("expected 2D logits, got shape %v", shape)
  276. }
  277. if row < 0 || row >= shape[0] {
  278. return nil, fmt.Errorf("row %d out of range", row)
  279. }
  280. if vocab <= 0 || vocab != shape[1] {
  281. return nil, fmt.Errorf("vocab mismatch: %d vs %d", vocab, shape[1])
  282. }
  283. if cpuT, ok := logits.(*cpu.Tensor); ok {
  284. start := row * vocab
  285. end := start + vocab
  286. data := cpuT.DataFloat32()
  287. if end > len(data) {
  288. return nil, fmt.Errorf("cpu logits out of range")
  289. }
  290. out := make([]float32, vocab)
  291. copy(out, data[start:end])
  292. return out, nil
  293. }
  294. if cudaT, ok := logits.(*cuda.Tensor); ok {
  295. view, err := cudaT.ViewAt(tensor.Shape{vocab}, uintptr(row*vocab*4))
  296. if err != nil {
  297. return nil, err
  298. }
  299. host := make([]float32, vocab)
  300. if err := view.CopyToHost(host); err != nil {
  301. return nil, err
  302. }
  303. return host, nil
  304. }
  305. if p, ok := logits.Data().(unsafe.Pointer); ok && p != nil {
  306. _ = p
  307. }
  308. return nil, fmt.Errorf("unsupported logits tensor type %T", logits)
  309. }
  310. func sampleNextTokenFromLogits(logits tensor.Tensor, row int, vocab int, sampler *sample.Sampler, recent []int) (int, error) {
  311. if logits == nil {
  312. return 0, fmt.Errorf("nil logits")
  313. }
  314. if sampler == nil {
  315. return 0, fmt.Errorf("nil sampler")
  316. }
  317. if row < 0 {
  318. return 0, fmt.Errorf("row %d out of range", row)
  319. }
  320. cfg := sampler.Config()
  321. k := cfg.TopK
  322. if cfg.Temperature == 0 {
  323. k = 1
  324. }
  325. // CPU logits: zero-copy row slice.
  326. if cpuT, ok := logits.(*cpu.Tensor); ok {
  327. shape := cpuT.Shape()
  328. if len(shape) != 2 || shape[1] != vocab {
  329. return 0, fmt.Errorf("expected logits shape [*,%d], got %v", vocab, shape)
  330. }
  331. if row >= shape[0] {
  332. return 0, fmt.Errorf("row %d out of range", row)
  333. }
  334. start := row * vocab
  335. end := start + vocab
  336. data := cpuT.DataFloat32()
  337. if end > len(data) {
  338. return 0, fmt.Errorf("cpu logits out of range")
  339. }
  340. return sampler.Sample(data[start:end], recent), nil
  341. }
  342. // CUDA logits: prefer GPU top-k path when enabled/greedy and supported by kernel.
  343. if cudaT, ok := logits.(*cuda.Tensor); ok {
  344. shape := cudaT.Shape()
  345. if len(shape) != 2 || shape[1] != vocab {
  346. return 0, fmt.Errorf("expected logits shape [*,%d], got %v", vocab, shape)
  347. }
  348. if row >= shape[0] {
  349. return 0, fmt.Errorf("row %d out of range", row)
  350. }
  351. view, err := cudaT.ViewAt(tensor.Shape{vocab}, uintptr(row*vocab*4))
  352. if err != nil {
  353. return 0, err
  354. }
  355. // CUDA TopK kernel supports k<=64; fall back when disabled or too large.
  356. if k <= 0 || k > 64 {
  357. host := make([]float32, vocab)
  358. if err := view.CopyToHost(host); err != nil {
  359. return 0, err
  360. }
  361. return sampler.Sample(host, recent), nil
  362. }
  363. repPenalty := cfg.RepetitionPenalty
  364. if repPenalty <= 0 {
  365. repPenalty = 1.0
  366. }
  367. repIDs := make([]int32, len(recent))
  368. for i, t := range recent {
  369. repIDs[i] = int32(t)
  370. }
  371. allIDs, allScores, blocks, err := cuda.TopKLogitsF32(view.Data().(unsafe.Pointer), vocab, repIDs, repPenalty, k, cudaT.GPU())
  372. if err != nil {
  373. return 0, err
  374. }
  375. cands := make([]struct {
  376. id int32
  377. score float32
  378. }, 0, blocks*k)
  379. for i := 0; i < blocks*k; i++ {
  380. if allIDs[i] < 0 {
  381. continue
  382. }
  383. cands = append(cands, struct {
  384. id int32
  385. score float32
  386. }{id: allIDs[i], score: allScores[i]})
  387. }
  388. if len(cands) == 0 {
  389. host := make([]float32, vocab)
  390. if err := view.CopyToHost(host); err != nil {
  391. return 0, err
  392. }
  393. return sampler.Sample(host, recent), nil
  394. }
  395. sort.Slice(cands, func(i, j int) bool { return cands[i].score > cands[j].score })
  396. if len(cands) > k {
  397. cands = cands[:k]
  398. }
  399. finalIDs := make([]int32, len(cands))
  400. finalScores := make([]float32, len(cands))
  401. for i := range cands {
  402. finalIDs[i] = cands[i].id
  403. finalScores[i] = cands[i].score
  404. }
  405. return sampler.SampleFromTopK(finalIDs, finalScores), nil
  406. }
  407. return 0, fmt.Errorf("unsupported logits type: %T", logits)
  408. }