server.go

package openai

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"sort"
	"strings"
	"sync"
	"time"
	"unicode/utf8"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/backend/device"
	"makarna/pkg/chat"
	"makarna/pkg/compute"
	"makarna/pkg/engine"
	"makarna/pkg/kvcache"
	"makarna/pkg/model"
	"makarna/pkg/sample"
	"makarna/pkg/tensor"
	"makarna/pkg/tokenizer"
)

type ChatCompletionRequest struct {
	Model            string                  `json:"model"`
	Messages         []ChatCompletionMessage `json:"messages"`
	Tools            []any                   `json:"tools,omitempty"`
	Stream           bool                    `json:"stream,omitempty"`
	SessionID        string                  `json:"session_id,omitempty"`
	MaxTokens        int                     `json:"max_tokens,omitempty"`
	Temperature      float64                 `json:"temperature,omitempty"`
	TopP             float64                 `json:"top_p,omitempty"`
	TopK             int                     `json:"top_k,omitempty"`
	PresencePenalty  float64                 `json:"presence_penalty,omitempty"`
	FrequencyPenalty float64                 `json:"frequency_penalty,omitempty"`
}

type CompletionRequest struct {
	Model       string  `json:"model"`
	Prompt      string  `json:"prompt"`
	Stream      bool    `json:"stream,omitempty"`
	MaxTokens   int     `json:"max_tokens,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
	TopP        float64 `json:"top_p,omitempty"`
	TopK        int     `json:"top_k,omitempty"`
}

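// utf8StreamBuffer accumulates streamed decode output and only releases complete
// UTF-8 sequences, so a multi-byte rune split across token pieces is never emitted
// as broken bytes. Flush returns whatever valid bytes remain and drops an invalid tail.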
type utf8StreamBuffer struct {
	buf []byte
}

func (b *utf8StreamBuffer) Push(s string) string {
	if s == "" {
		return ""
	}
	b.buf = append(b.buf, []byte(s)...)
	outLen := 0
	for outLen < len(b.buf) {
		_, size := utf8.DecodeRune(b.buf[outLen:])
		if size == 0 {
			break
		}
		if size == 1 {
			if !utf8.FullRune(b.buf[outLen:]) {
				break
			}
		}
		outLen += size
	}
	if outLen == 0 {
		return ""
	}
	out := string(b.buf[:outLen])
	b.buf = b.buf[outLen:]
	return out
}

func (b *utf8StreamBuffer) Flush() string {
	if len(b.buf) == 0 {
		return ""
	}
	if !utf8.Valid(b.buf) {
		b.buf = nil
		return ""
	}
	out := string(b.buf)
	b.buf = nil
	return out
}

type ChatCompletionMessage struct {
	Role       string `json:"role"`
	Content    string `json:"content"`
	Name       string `json:"name,omitempty"`
	ToolCallID string `json:"tool_call_id,omitempty"`
}

type ChatCompletionResponse struct {
	ID      string                 `json:"id"`
	Object  string                 `json:"object"`
	Created int64                  `json:"created"`
	Model   string                 `json:"model"`
	Usage   ChatCompletionUsage    `json:"usage"`
	Choices []ChatCompletionChoice `json:"choices"`
}

type CompletionResponse struct {
	ID      string             `json:"id"`
	Object  string             `json:"object"`
	Created int64              `json:"created"`
	Model   string             `json:"model"`
	Choices []CompletionChoice `json:"choices"`
	Usage   CompletionUsage    `json:"usage"`
}

type CompletionChoice struct {
	Index        int    `json:"index"`
	Text         string `json:"text"`
	FinishReason string `json:"finish_reason"`
	LogProbs     any    `json:"logprobs"`
}

type CompletionUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type openAIErrorResponse struct {
	Error openAIError `json:"error"`
}

type openAIError struct {
	Message string  `json:"message"`
	Type    string  `json:"type"`
	Param   *string `json:"param,omitempty"`
	Code    *string `json:"code,omitempty"`
}

type contextLengthExceededError struct {
	PromptTokens int
	MaxSeqLen    int
}

func (e *contextLengthExceededError) Error() string {
	return fmt.Sprintf("context length exceeded: prompt_tokens=%d max_seq_len=%d", e.PromptTokens, e.MaxSeqLen)
}

type ChatCompletionUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type ChatCompletionChunkResponse struct {
	ID      string                      `json:"id"`
	Object  string                      `json:"object"`
	Created int64                       `json:"created"`
	Model   string                      `json:"model"`
	Choices []ChatCompletionChunkChoice `json:"choices"`
}

type ChatCompletionChunkChoice struct {
	Index        int       `json:"index"`
	Delta        ChatDelta `json:"delta"`
	FinishReason *string   `json:"finish_reason"`
}

type ChatDelta struct {
	Role    string `json:"role,omitempty"`
	Content string `json:"content,omitempty"`
}

type ChatCompletionChoice struct {
	Index        int                  `json:"index"`
	Message      ChatCompletionOutMsg `json:"message"`
	FinishReason string               `json:"finish_reason"`
}

type ChatCompletionOutMsg struct {
	Role      string           `json:"role"`
	Content   string           `json:"content"`
	ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
}

type OpenAIToolCall struct {
	ID       string             `json:"id"`
	Type     string             `json:"type"`
	Function OpenAIFunctionCall `json:"function"`
}

type OpenAIFunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}

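// Config holds the server's tuning knobs: listen address, context length, KV-cache
// block size and placement, thinking-mode toggle, prefill chunk size, and the number
// of concurrent requests the KV pool is sized for.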
type Config struct {
	Listen           string
	MaxSeqLen        int
	BlockSize        int
	KVCacheCPU       bool
	EnableThinking   bool
	PrefillChunkSize int
	MaxConcurrent    int
}

type Server struct {
	eng              *engine.Engine
	batcher          *engine.Batcher
	tok              *tokenizer.Tokenizer
	arch             string
	modelID          string
	maxSeqLen        int
	blockSize        int
	prefillChunkSize int
	mu               sync.Mutex
	// global block pool for prefix caching (used by standard attention models)
	blockPool       *kvcache.BlockPool
	blockPoolDevice tensor.DeviceType
	blockPoolGPU    int
	// cacheFactory is set if model implements model.CacheFactory (e.g., KimiLinear, Mamba).
	// When set, we use the model's cache instead of PagedKVCache.
	cacheFactory   model.CacheFactory
	scratchPool    chan *compute.ScratchSet
	scratchGPUs    []int
	scratchBytes   int
	enableThinking bool
	stripTokens    []string
}

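// Serve applies config defaults, sizes the KV-cache block pool and GPU scratch buffers,
// optionally starts the decode batcher and an asynchronous warmup pass, then registers
// the OpenAI-compatible endpoints and blocks in http.ListenAndServe.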
func Serve(eng *engine.Engine, tok *tokenizer.Tokenizer, arch string, cfg Config) error {
	if cfg.Listen == "" {
		cfg.Listen = ":8080"
	}
	if cfg.MaxSeqLen <= 0 {
		cfg.MaxSeqLen = 8192
	}
	if cfg.BlockSize <= 0 {
		cfg.BlockSize = 32
	}
	if cfg.PrefillChunkSize <= 0 {
		cfg.PrefillChunkSize = 512
	}
	if cfg.MaxConcurrent <= 0 {
		cfg.MaxConcurrent = 4
	}
	s := &Server{
		eng:              eng,
		batcher:          engine.NewBatcher(eng),
		tok:              tok,
		arch:             arch,
		modelID:          arch,
		maxSeqLen:        cfg.MaxSeqLen,
		blockSize:        cfg.BlockSize,
		enableThinking:   cfg.EnableThinking,
		prefillChunkSize: cfg.PrefillChunkSize,
	}
	// Check if model implements CacheFactory (e.g., KimiLinear, Mamba).
	if cf, ok := eng.Model().(model.CacheFactory); ok {
		s.cacheFactory = cf
		log.Printf("kv_cache: using model's CacheFactory (type=%d)", cf.CacheType())
	}
	// Only start batcher for models using PagedKVCache (standard attention).
	// Recurrent models don't support batching yet.
	if s.batcher != nil && s.cacheFactory == nil {
		s.batcher.Start()
	} else {
		s.batcher = nil
	}
	// Initialize global BlockPool for prefix caching (only for standard attention models).
	modelCfg := eng.Model().Config()
	var (
		layerPlacements []tensor.DevicePlacement
		anyCUDALayer    bool
	)
	if eng.Dispatcher() != nil {
		layerPlacements = make([]tensor.DevicePlacement, modelCfg.NumLayers)
		for i := 0; i < modelCfg.NumLayers; i++ {
			p := eng.Dispatcher().LayerPlacement(i).Normalize()
			layerPlacements[i] = p
			if p.Type == tensor.CUDA && p.GPU >= 0 {
				anyCUDALayer = true
			}
		}
	}
	// Skip BlockPool for models using CacheFactory (recurrent state models).
	if s.cacheFactory == nil {
		blocksPerRequest := (cfg.MaxSeqLen + cfg.BlockSize - 1) / cfg.BlockSize
		blocksPerLayer := blocksPerRequest * cfg.MaxConcurrent
		if blocksPerLayer <= 0 {
			return fmt.Errorf("invalid KV pool sizing: max_seq_len=%d block_size=%d max_concurrent=%d", cfg.MaxSeqLen, cfg.BlockSize, cfg.MaxConcurrent)
		}
		// Support mixed CPU/GPU offload: allocate KV blocks per-layer on the same device as that layer.
		// This avoids the pathological "KV on CPU forces attention on CPU" slowdown when not all layers fit on GPU.
		if !cfg.KVCacheCPU && device.CUDAAvailable() && anyCUDALayer {
			blockPool, err := kvcache.NewBlockPool(kvcache.BlockPoolConfig{
				NumLayers:       modelCfg.NumLayers,
				NumKVHeads:      modelCfg.NumKVHeads,
				HeadDim:         modelCfg.HeadDim,
				BlockSize:       cfg.BlockSize,
				NumBlocks:       blocksPerLayer,
				Device:          tensor.CUDA, // default; LayerPlacements can still contain CPU layers
				GPU:             0,
				LayerPlacements: layerPlacements,
				Preallocate:     true,
			})
			if err != nil {
				return fmt.Errorf("KV cache reserve failed (max_seq_len=%d max_concurrent=%d): %w", cfg.MaxSeqLen, cfg.MaxConcurrent, err)
			}
			s.blockPool = blockPool
			s.blockPoolDevice = tensor.CUDA
			s.blockPoolGPU = 0
			log.Printf("kv_cache: device=mixed block_size=%d blocks_per_layer=%d", cfg.BlockSize, blocksPerLayer)
		} else {
			blockPool, err := kvcache.NewBlockPool(kvcache.BlockPoolConfig{
				NumLayers:  modelCfg.NumLayers,
				NumKVHeads: modelCfg.NumKVHeads,
				HeadDim:    modelCfg.HeadDim,
				BlockSize:  cfg.BlockSize,
				NumBlocks:  blocksPerLayer,
				Device:     tensor.CPU,
				GPU:        0,
			})
			if err != nil {
				return fmt.Errorf("KV cache init failed: %w", err)
			}
			s.blockPool = blockPool
			s.blockPoolDevice = tensor.CPU
			s.blockPoolGPU = 0
			log.Printf("kv_cache: device=cpu block_size=%d blocks_per_layer=%d", cfg.BlockSize, blocksPerLayer)
		}
	}
	// Allocate scratch sets up-front so prefill/decode doesn't trigger cudaMalloc churn.
	if device.CUDAAvailable() && eng != nil && eng.Dispatcher() != nil {
		gpus := collectDispatcherGPUs(eng.Dispatcher(), modelCfg.NumLayers)
		if len(gpus) > 0 {
			seqLen := cfg.PrefillChunkSize
			if seqLen <= 0 {
				seqLen = 512
			}
			if cfg.MaxSeqLen > 0 && seqLen > cfg.MaxSeqLen {
				seqLen = cfg.MaxSeqLen
			}
			scratchBytes := estimateScratchBytes(modelCfg, seqLen)
			poolSize := cfg.MaxConcurrent
			if poolSize <= 0 {
				poolSize = 1
			}
			const minScratchBytes = 8 << 20
			var scratchErr error
			for scratchBytes >= minScratchBytes {
				created := make([]*compute.ScratchSet, 0, poolSize)
				ok := true
				for i := 0; i < poolSize; i++ {
					ss, err := compute.NewScratchSet(gpus, scratchBytes)
					if err != nil {
						scratchErr = err
						ok = false
						break
					}
					created = append(created, ss)
				}
				if ok {
					s.scratchPool = make(chan *compute.ScratchSet, poolSize)
					for _, ss := range created {
						s.scratchPool <- ss
					}
					s.scratchGPUs = gpus
					s.scratchBytes = scratchBytes
					log.Printf("scratch: gpus=%v bytes=%d sets=%d", gpus, scratchBytes, poolSize)
					break
				}
				for _, prev := range created {
					prev.Free()
				}
				scratchBytes /= 2
			}
			if s.scratchPool == nil && scratchErr != nil {
				log.Printf("scratch disabled (reserve failed): %v", scratchErr)
			}
		}
	}
	// Warm up GPU weight caches so the first request doesn't spend seconds uploading weights.
	// Run asynchronously so the server can start listening immediately (important for OpenWebUI discovery).
	if s.scratchPool != nil && s.eng != nil {
		go func() {
			ss := <-s.scratchPool
			defer func() { s.scratchPool <- ss }()
			warmCtx := context.Background()
			warmCtx = compute.WithScratchSet(warmCtx, ss)
			if len(s.scratchGPUs) > 0 {
				if sc := ss.Scratch(s.scratchGPUs[0]); sc != nil {
					warmCtx = compute.WithScratch(warmCtx, sc)
				}
			}
			warmSeqLen := cfg.PrefillChunkSize
			if warmSeqLen <= 0 {
				warmSeqLen = 512
			}
			if cfg.MaxSeqLen > 0 && warmSeqLen > cfg.MaxSeqLen {
				warmSeqLen = cfg.MaxSeqLen
			}
			warmIDs := make([]int, warmSeqLen)
			_, err := s.eng.Forward(warmCtx, createInputTensor(warmIDs), createPositionTensor(0, warmSeqLen), nil)
			if err != nil {
				log.Printf("warmup forward failed: %v", err)
				return
			}
			log.Printf("warmup forward ok seq_len=%d", warmSeqLen)
			compute.LogWeightCacheSummary()
		}()
	}
	if tok != nil {
		for _, st := range tok.AddedTokenStrings() {
			if strings.Contains(strings.ToLower(st), "think") {
				continue
			}
			s.stripTokens = append(s.stripTokens, st)
		}
	}
	mux := http.NewServeMux()
	mux.HandleFunc("/v1/chat/completions", s.handleChatCompletions)
	mux.HandleFunc("/v1/completions", s.handleCompletions)
	mux.HandleFunc("/v1/models", s.handleModels)
	mux.HandleFunc("/", s.handleRoot)
	h := withCORS(mux)
	log.Printf("listening on %s (arch=%s, cuda=%v)", cfg.Listen, arch, device.CUDAAvailable())
	return http.ListenAndServe(cfg.Listen, h)
}

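// withCORS adds permissive CORS headers to every response and short-circuits
// OPTIONS preflight requests with 204 No Content.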
func withCORS(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
		w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}

func (s *Server) handleModels(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	resp := map[string]any{
		"object": "list",
		"data": []any{map[string]any{
			"id":       s.modelID,
			"object":   "model",
			"owned_by": "local",
		}},
	}
	writeJSON(w, resp)
}

func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request) {
	if r.URL.Path != "/" {
		http.NotFound(w, r)
		return
	}
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	writeJSON(w, map[string]any{
		"status": "ok",
		"api": map[string]any{
			"chat_completions": "/v1/chat/completions",
			"completions":      "/v1/completions",
			"models":           "/v1/models",
		},
	})
}

func (s *Server) handleCompletions(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	start := time.Now()
	var req CompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "bad json", http.StatusBadRequest)
		return
	}
	if req.Model == "" {
		req.Model = s.modelID
	}
	if req.Prompt == "" {
		http.Error(w, "prompt required", http.StatusBadRequest)
		return
	}
	if req.Stream {
		http.Error(w, "stream not implemented", http.StatusNotImplemented)
		return
	}
	maxTokens := req.MaxTokens
	if maxTokens <= 0 {
		maxTokens = 128
	}
	temp := req.Temperature
	if temp == 0 {
		temp = 0.7
	}
	topP := req.TopP
	if topP == 0 {
		topP = 0.9
	}
	topK := req.TopK
	if topK == 0 {
		topK = 40
	}
	requestID := fmt.Sprintf("cmpl_%d", rand.Int63())
	log.Printf("/v1/completions id=%s stream=false remote=%s model=%s max_tokens=%d", requestID, r.RemoteAddr, req.Model, maxTokens)
	outText, genErr := s.generate(r.Context(), requestID, req.Prompt, maxTokens, temp, topP, topK)
	if genErr != nil {
		log.Printf("/v1/completions id=%s error=%v", requestID, genErr)
		http.Error(w, fmt.Sprintf("generate: %v", genErr), http.StatusInternalServerError)
		return
	}
	promptTokens := 0
	completionTokens := 0
	if s.tok != nil {
		promptTokens = len(s.tok.Encode(req.Prompt))
		completionTokens = len(s.tok.Encode(outText))
	}
	resp := CompletionResponse{
		ID:      requestID,
		Object:  "text_completion",
		Created: time.Now().Unix(),
		Model:   req.Model,
		Choices: []CompletionChoice{{
			Index:        0,
			Text:         outText,
			FinishReason: "stop",
			LogProbs:     nil,
		}},
		Usage: CompletionUsage{
			PromptTokens:     promptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      promptTokens + completionTokens,
		},
	}
	log.Printf("/v1/completions id=%s done prompt_tokens=%d completion_tokens=%d ms=%d", requestID, promptTokens, completionTokens, time.Since(start).Milliseconds())
	writeJSON(w, resp)
}

func (s *Server) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	start := time.Now()
	var req ChatCompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "bad json", http.StatusBadRequest)
		return
	}
	if req.Model == "" {
		req.Model = s.modelID
	}
	if len(req.Messages) == 0 {
		http.Error(w, "messages required", http.StatusBadRequest)
		return
	}
	if req.Stream {
		s.handleChatCompletionsStream(w, r, req)
		return
	}
	requestID := fmt.Sprintf("chatcmpl_%d", rand.Int63())
	msgs := make([]chat.Message, 0, len(req.Messages))
	for _, m := range req.Messages {
		role := strings.ToLower(m.Role)
		msgs = append(msgs, chat.Message{Role: role, Content: m.Content})
	}
	prompt, promptTokens, truncated, err := s.renderPromptWithinBudget(msgs, req.Tools)
	if err != nil {
		var cl *contextLengthExceededError
		if errors.As(err, &cl) {
			writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("This model's maximum context length is %d tokens. However, your messages resulted in %d tokens.", cl.MaxSeqLen, cl.PromptTokens), "context_length_exceeded", "messages")
			return
		}
		http.Error(w, fmt.Sprintf("render prompt: %v", err), http.StatusInternalServerError)
		return
	}
	if truncated {
		log.Printf("/v1/chat/completions id=%s prompt_truncated=true", requestID)
	}
	maxTokens := s.computeMaxTokens(req.MaxTokens, promptTokens)
	log.Printf("/v1/chat/completions id=%s stream=false remote=%s model=%s messages=%d max_tokens=%d prompt_tokens=%d", requestID, r.RemoteAddr, req.Model, len(req.Messages), maxTokens, promptTokens)
	temp := req.Temperature
	if temp == 0 {
		temp = 0.7
	}
	topP := req.TopP
	if topP == 0 {
		topP = 0.9
	}
	topK := req.TopK
	if topK == 0 {
		topK = 40
	}
	outText, genErr := s.generate(r.Context(), requestID, prompt, maxTokens, temp, topP, topK)
	if genErr != nil {
		log.Printf("/v1/chat/completions id=%s error=%v", requestID, genErr)
		http.Error(w, fmt.Sprintf("generate: %v", genErr), http.StatusInternalServerError)
		return
	}
	genTokens := 0
	if s.tok != nil {
		genTokens = len(s.tok.Encode(outText))
	}
	content := strings.TrimSpace(outText)
	content, calls, err := chat.ExtractToolCalls(content)
	if err != nil {
		http.Error(w, fmt.Sprintf("parse tool_calls: %v", err), http.StatusInternalServerError)
		return
	}
	content = strings.TrimSpace(content)
	content = sanitizeAssistantContent(content, s.stripTokens)
	content = strings.TrimSpace(content)
	var outCalls []OpenAIToolCall
	for i, c := range calls {
		outCalls = append(outCalls, OpenAIToolCall{
			ID:   fmt.Sprintf("call_%d_%d", time.Now().UnixNano(), i),
			Type: "function",
			Function: OpenAIFunctionCall{
				Name:      c.Name,
				Arguments: string(bytesOrEmptyObject(c.Arguments)),
			},
		})
	}
	finish := "stop"
	if len(outCalls) > 0 {
		finish = "tool_calls"
	}
	log.Printf("/v1/chat/completions id=%s done stream=false prompt_tokens=%d gen_tokens=%d tool_calls=%d ms=%d", requestID, promptTokens, genTokens, len(outCalls), time.Since(start).Milliseconds())
	resp := ChatCompletionResponse{
		ID:      requestID,
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   req.Model,
		Usage: ChatCompletionUsage{
			PromptTokens:     promptTokens,
			CompletionTokens: genTokens,
			TotalTokens:      promptTokens + genTokens,
		},
		Choices: []ChatCompletionChoice{{
			Index: 0,
			Message: ChatCompletionOutMsg{
				Role:      "assistant",
				Content:   content,
				ToolCalls: outCalls,
			},
			FinishReason: finish,
		}},
	}
	writeJSON(w, resp)
}

func (s *Server) handleChatCompletionsStream(w http.ResponseWriter, r *http.Request, req ChatCompletionRequest) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}
	start := time.Now()
	msg := make([]chat.Message, 0, len(req.Messages))
	for _, m := range req.Messages {
		role := strings.ToLower(m.Role)
		msg = append(msg, chat.Message{Role: role, Content: m.Content})
	}
	id := fmt.Sprintf("chatcmpl_%d", rand.Int63())
	created := time.Now().Unix()
	prompt, promptTokens, truncated, err := s.renderPromptWithinBudget(msg, req.Tools)
	if err != nil {
		var cl *contextLengthExceededError
		if errors.As(err, &cl) {
			writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("This model's maximum context length is %d tokens. However, your messages resulted in %d tokens.", cl.MaxSeqLen, cl.PromptTokens), "context_length_exceeded", "messages")
			return
		}
		http.Error(w, fmt.Sprintf("render prompt: %v", err), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	maxTokens := s.computeMaxTokens(req.MaxTokens, promptTokens)
	temp := req.Temperature
	if temp == 0 {
		temp = 0.7
	}
	topP := req.TopP
	if topP == 0 {
		topP = 0.9
	}
	topK := req.TopK
	if topK == 0 {
		topK = 40
	}
	if truncated {
		log.Printf("/v1/chat/completions id=%s prompt_truncated=true", id)
	}
	log.Printf("/v1/chat/completions id=%s stream=true remote=%s model=%s messages=%d max_tokens=%d prompt_tokens=%d", id, r.RemoteAddr, req.Model, len(req.Messages), maxTokens, promptTokens)
	if err := writeSSEJSON(w, ChatCompletionChunkResponse{
		ID:      id,
		Object:  "chat.completion.chunk",
		Created: created,
		Model:   req.Model,
		Choices: []ChatCompletionChunkChoice{{
			Index: 0,
			Delta: ChatDelta{Role: "assistant"},
		}},
	}); err != nil {
		return
	}
	flusher.Flush()
	genTokens := 0
	utf8buf := &utf8StreamBuffer{}
	_, genErr := s.generateStream(r.Context(), id, prompt, maxTokens, temp, topP, topK, func(piece string) error {
		genTokens++
		// Decode can emit partial UTF-8 sequences across tokens.
		// Buffer first, then sanitize, otherwise sanitize may corrupt multibyte sequences.
		piece = utf8buf.Push(piece)
		if piece == "" {
			return nil
		}
		piece = sanitizeAssistantContent(piece, s.stripTokens)
		if err := writeSSEJSON(w, ChatCompletionChunkResponse{
			ID:      id,
			Object:  "chat.completion.chunk",
			Created: created,
			Model:   req.Model,
			Choices: []ChatCompletionChunkChoice{{
				Index: 0,
				Delta: ChatDelta{Content: piece},
			}},
		}); err != nil {
			return err
		}
		flusher.Flush()
		return nil
	})
	if genErr != nil {
		log.Printf("/v1/chat/completions id=%s error=%v", id, genErr)
		if errors.Is(genErr, context.Canceled) || errors.Is(r.Context().Err(), context.Canceled) {
			finish := "cancelled"
			_ = writeSSEJSON(w, ChatCompletionChunkResponse{
				ID:      id,
				Object:  "chat.completion.chunk",
				Created: created,
				Model:   req.Model,
				Choices: []ChatCompletionChunkChoice{{
					Index:        0,
					Delta:        ChatDelta{},
					FinishReason: &finish,
				}},
			})
			flusher.Flush()
			_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
			flusher.Flush()
		}
		return
	}
	if tail := utf8buf.Flush(); tail != "" {
		tail = sanitizeAssistantContent(tail, s.stripTokens)
		_ = writeSSEJSON(w, ChatCompletionChunkResponse{
			ID:      id,
			Object:  "chat.completion.chunk",
			Created: created,
			Model:   req.Model,
			Choices: []ChatCompletionChunkChoice{{
				Index: 0,
				Delta: ChatDelta{Content: tail},
			}},
		})
		flusher.Flush()
	}
	finish := "stop"
	if err := writeSSEJSON(w, ChatCompletionChunkResponse{
		ID:      id,
		Object:  "chat.completion.chunk",
		Created: created,
		Model:   req.Model,
		Choices: []ChatCompletionChunkChoice{{
			Index:        0,
			Delta:        ChatDelta{},
			FinishReason: &finish,
		}},
	}); err != nil {
		return
	}
	flusher.Flush()
	_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
	flusher.Flush()
	log.Printf("/v1/chat/completions id=%s done stream=true prompt_tokens=%d gen_tokens=%d ms=%d", id, promptTokens, genTokens, time.Since(start).Milliseconds())
}

func bytesOrEmptyObject(b []byte) []byte {
	if len(b) == 0 {
		return []byte("{}")
	}
	return b
}

func sanitizeAssistantContent(s string, stripTokens []string) string {
	if s == "" {
		return s
	}
	for _, tok := range stripTokens {
		if tok == "" {
			continue
		}
		s = strings.ReplaceAll(s, tok, "")
	}
	return s
}

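// computeMaxTokens clamps the requested completion budget to the tokens left in the
// context window after the prompt; a non-positive request means "use the whole
// remaining budget".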
func (s *Server) computeMaxTokens(reqMaxTokens int, promptTokens int) int {
	budget := s.maxSeqLen - promptTokens
	if budget < 1 {
		budget = 1
	}
	if reqMaxTokens <= 0 {
		return budget
	}
	if reqMaxTokens > budget {
		return budget
	}
	return reqMaxTokens
}

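// renderPromptWithinBudget renders the chat template for the configured architecture and
// returns the prompt, its token count, and a truncation flag (currently always false).
// It fails with contextLengthExceededError when the prompt alone exceeds maxSeqLen.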
func (s *Server) renderPromptWithinBudget(msgs []chat.Message, tools []any) (string, int, bool, error) {
	p, err := chat.RenderForArchitecture(s.arch, msgs, chat.Options{
		AddGenerationPrompt: true,
		EnableThinking:      s.enableThinking,
		Tools:               tools,
	})
	if err != nil {
		return "", 0, false, err
	}
	if s.tok == nil {
		return p, 0, false, nil
	}
	pt := len(s.tok.Encode(p))
	if pt > s.maxSeqLen {
		return "", pt, false, &contextLengthExceededError{PromptTokens: pt, MaxSeqLen: s.maxSeqLen}
	}
	return p, pt, false, nil
}

func writeOpenAIError(w http.ResponseWriter, status int, message string, code string, param string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	typeStr := "invalid_request_error"
	paramCopy := param
	codeCopy := code
	_ = json.NewEncoder(w).Encode(openAIErrorResponse{
		Error: openAIError{
			Message: message,
			Type:    typeStr,
			Param:   &paramCopy,
			Code:    &codeCopy,
		},
	})
}

func writeJSON(w http.ResponseWriter, v any) {
	w.Header().Set("Content-Type", "application/json")
	enc := json.NewEncoder(w)
	enc.SetEscapeHTML(false)
	_ = enc.Encode(v)
}

func writeSSEJSON(w http.ResponseWriter, v any) error {
	b, err := json.Marshal(v)
	if err != nil {
		return err
	}
	_, err = fmt.Fprintf(w, "data: %s\n\n", b)
	return err
}

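// collectDispatcherGPUs returns the sorted, de-duplicated set of CUDA GPUs that the
// dispatcher has placed at least one layer on.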
func collectDispatcherGPUs(d *device.DeviceDispatcher, numLayers int) []int {
	if d == nil || numLayers <= 0 {
		return nil
	}
	seen := make(map[int]struct{})
	out := make([]int, 0, 4)
	for i := 0; i < numLayers; i++ {
		p := d.LayerPlacement(i).Normalize()
		if p.Type != tensor.CUDA || p.GPU < 0 {
			continue
		}
		if _, ok := seen[p.GPU]; ok {
			continue
		}
		seen[p.GPU] = struct{}{}
		out = append(out, p.GPU)
	}
	sort.Ints(out)
	return out
}

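// estimateScratchBytes sizes per-GPU scratch for one prefill chunk: fp32 attention
// (q/k/v, attention output, output projection) and MLP (gate, up, activation, output)
// activations for seqLen rows, plus slack for pointer tables, rounded up to a 16 MiB
// multiple and floored at compute.DefaultScratchBytes.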
func estimateScratchBytes(cfg *model.Config, seqLen int) int {
	if cfg == nil || seqLen <= 0 {
		return compute.DefaultScratchBytes
	}
	hiddenSize := cfg.HiddenSize
	numHeads := cfg.NumHeads
	numKVHeads := cfg.NumKVHeads
	if numKVHeads == 0 {
		numKVHeads = numHeads
	}
	headDim := cfg.HeadDim
	if headDim == 0 && numHeads > 0 {
		headDim = hiddenSize / numHeads
	}
	intermediate := cfg.Intermediate
	if intermediate == 0 {
		intermediate = hiddenSize * 4
	}
	qDim := numHeads * headDim
	kvDim := numKVHeads * headDim
	add := func(rows, cols int) int64 {
		if rows <= 0 || cols <= 0 {
			return 0
		}
		return int64(rows) * int64(cols) * 4
	}
	var bytes int64
	// Attention scratch.
	bytes += add(seqLen, qDim)       // qOut
	bytes += add(seqLen, kvDim)      // kOut
	bytes += add(seqLen, kvDim)      // vOut
	bytes += add(seqLen, qDim)       // attnOut
	bytes += add(seqLen, hiddenSize) // attnProj
	// MLP scratch.
	bytes += add(seqLen, intermediate) // gate
	bytes += add(seqLen, intermediate) // up
	bytes += add(seqLen, intermediate) // act
	bytes += add(seqLen, hiddenSize)   // mlpOut
	// Extra space for ptr tables / int32 slices / alignment slack.
	bytes += 8 << 20
	// Round up to 16MB.
	const align = int64(16 << 20)
	if rem := bytes % align; rem != 0 {
		bytes += align - rem
	}
	if bytes < int64(compute.DefaultScratchBytes) {
		bytes = int64(compute.DefaultScratchBytes)
	}
	maxInt := int64(int(^uint(0) >> 1))
	if bytes > maxInt {
		bytes = maxInt
	}
	return int(bytes)
}

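// acquireScratchSet borrows a ScratchSet from the pool (or returns nil with a no-op
// release when no pool exists), resetting it on both checkout and return. It blocks
// until a set is free or ctx is cancelled.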
func (s *Server) acquireScratchSet(ctx context.Context) (*compute.ScratchSet, func(), error) {
	if s == nil || s.scratchPool == nil {
		return nil, func() {}, nil
	}
	if ctx == nil {
		ctx = context.Background()
	}
	select {
	case ss := <-s.scratchPool:
		if ss != nil {
			ss.Reset()
		}
		release := func() {
			if ss != nil {
				ss.Reset()
			}
			s.scratchPool <- ss
		}
		return ss, release, nil
	case <-ctx.Done():
		return nil, nil, ctx.Err()
	}
}

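// prefill runs the prompt through the model in prefillChunkSize pieces, committing each
// chunk to the paged KV cache and returning the logits of the final chunk.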
func (s *Server) prefill(ctx context.Context, cache kvcache.KVCacheInterface, ids []int) (tensor.Tensor, error) {
	if len(ids) == 0 {
		return nil, fmt.Errorf("no prompt tokens to prefill")
	}
	chunk := s.prefillChunkSize
	if chunk <= 0 {
		chunk = 512
	}
	var logits tensor.Tensor
	for start := 0; start < len(ids); start += chunk {
		end := start + chunk
		if end > len(ids) {
			end = len(ids)
		}
		part := ids[start:end]
		input := createInputTensor(part)
		positions := createPositionTensor(cache.SeqLen(), len(part))
		before := cache.SeqLen()
		out, err := s.eng.Forward(ctx, input, positions, cache)
		if err != nil {
			return nil, err
		}
		logits = out
		// Some model implementations advance the cache internally.
		// If the cache didn't advance, commit here.
		if cache.SeqLen() == before {
			cache.Commit(len(part))
		}
	}
	return logits, nil
}

// prefillWithModelCache processes all prompt tokens in a single forward pass using the
// model's custom cache (for recurrent models). Unlike prefill, which chunks tokens,
// recurrent models such as KimiLinear consume the whole prompt at once so their
// recurrent state is updated in token order.
func (s *Server) prefillWithModelCache(ctx context.Context, cache model.KVCache, ids []int) (tensor.Tensor, error) {
	if len(ids) == 0 {
		return nil, fmt.Errorf("no prompt tokens to prefill")
	}
	input := createInputTensor(ids)
	positions := createPositionTensor(cache.SeqLen(), len(ids))
	before := cache.SeqLen()
	logits, err := s.eng.Forward(ctx, input, positions, cache)
	if err != nil {
		return nil, err
	}
	// Commit if the cache didn't advance internally.
	if cache.SeqLen() == before {
		cache.Commit(len(ids))
	}
	return logits, nil
}

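// generate runs the non-streaming path: tokenize the prompt, build a paged or
// model-provided cache, prefill, then decode token by token (or via the batcher when
// available) until EOS or maxTokens, returning the decoded text.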
func (s *Server) generate(ctx context.Context, requestID string, prompt string, maxTokens int, temperature float64, topP float64, topK int) (string, error) {
	scratchSet, releaseScratch, err := s.acquireScratchSet(ctx)
	if err != nil {
		return "", err
	}
	if scratchSet != nil {
		defer releaseScratch()
		ctx = compute.WithScratchSet(ctx, scratchSet)
		if len(s.scratchGPUs) > 0 {
			if sc := scratchSet.Scratch(s.scratchGPUs[0]); sc != nil {
				ctx = compute.WithScratch(ctx, sc)
			}
		}
	}
	ids := s.tok.Encode(prompt)
	if len(ids) == 0 {
		return "", fmt.Errorf("empty prompt after tokenization")
	}
	// Create cache - use model's CacheFactory if available, otherwise PagedKVCache.
	var cache kvcache.KVCacheInterface
	var pagedCache *kvcache.PagedKVCache
	var modelCache model.KVCache
	var cachedTokens int
	if s.cacheFactory != nil {
		// Model uses custom cache (e.g., KimiLinear with recurrent state).
		mc, err := s.cacheFactory.CreateCache()
		if err != nil {
			return "", fmt.Errorf("model cache creation failed: %w", err)
		}
		modelCache = mc
		// modelCache implements model.KVCache but not kvcache.KVCacheInterface;
		// we use it directly with the model's Forward function.
	} else {
		// Standard attention model - use PagedKVCache.
		if s.blockPool == nil {
			return "", fmt.Errorf("BlockPool not initialized")
		}
		modelCfg := s.eng.Model().Config()
		pagedCache = kvcache.NewPagedKVCache(s.blockPool, kvcache.PagedCacheConfig{
			NumLayers:  modelCfg.NumLayers,
			NumKVHeads: modelCfg.NumKVHeads,
			HeadDim:    modelCfg.HeadDim,
			BlockSize:  s.blockSize,
			MaxSeqLen:  s.maxSeqLen,
			Device:     s.blockPoolDevice,
			GPU:        s.blockPoolGPU,
		}, requestID)
		cachedTokens, err = pagedCache.AllocateForTokens(ids)
		if err != nil {
			pagedCache.Free()
			return "", fmt.Errorf("PagedKVCache alloc failed: %w", err)
		}
		defer pagedCache.Free()
		cache = pagedCache
		if cachedTokens > 0 {
			log.Printf("prefix_cache_hit request=%s cached_tokens=%d prompt_tokens=%d", requestID, cachedTokens, len(ids))
		}
	}
	sampler := sample.New(sample.Config{
		Temperature:       float32(temperature),
		TopK:              topK,
		TopP:              float32(topP),
		RepetitionPenalty: 1.1,
		Seed:              -1,
	})
	// Prefill - skip cached tokens if we have a prefix cache hit.
	prefillIDs := ids
	if cachedTokens > 0 && cachedTokens < len(ids) {
		prefillIDs = ids[cachedTokens:]
	}
	var logits tensor.Tensor
	if modelCache != nil {
		// Use model's cache (recurrent state) - no chunked prefill, process all at once.
		logits, err = s.prefillWithModelCache(ctx, modelCache, prefillIDs)
	} else {
		logits, err = s.prefill(ctx, cache, prefillIDs)
	}
	if err != nil {
		return "", err
	}
	var nextToken int
	chunk := s.prefillChunkSize
	if chunk <= 0 {
		chunk = 512
	}
	lastPartLen := len(prefillIDs) % chunk
	if lastPartLen == 0 {
		lastPartLen = min(chunk, len(prefillIDs))
	}
	rowIdx := lastPartLen - 1
	if modelCache != nil {
		// For recurrent models, we processed all tokens at once.
		rowIdx = len(prefillIDs) - 1
	}
	nextToken, err = sampleNextToken(logits, rowIdx, sampler, ids)
	if err != nil {
		return "", err
	}
	ids = append(ids, nextToken)
	if pagedCache != nil {
		pagedCache.AppendToken(nextToken)
	}
	var sb strings.Builder
	sb.WriteString(s.tok.Decode([]int{nextToken}))
	eosID := s.tok.EosID()
	useBatcher := s.batcher != nil && pagedCache != nil
	if useBatcher {
		if _, ok := s.eng.Model().(model.BatchForwarder); !ok {
			useBatcher = false
		}
	}
	if useBatcher {
		seq := &engine.DecodeSequence{
			RequestID:      requestID,
			Ctx:            ctx,
			Cache:          cache,
			History:        ids,
			NextInputToken: nextToken,
			Remaining:      maxTokens - 1,
			EosID:          eosID,
			Sampler:        sampler,
		}
		events, err := s.batcher.RegisterDecode(seq)
		if err != nil {
			return "", err
		}
		for ev := range events {
			if ev.Err != nil {
				return "", ev.Err
			}
			if ev.Done {
				break
			}
			ids = append(ids, ev.Token)
			sb.WriteString(s.tok.Decode([]int{ev.Token}))
		}
		return sb.String(), nil
	}
	// Decode loop - handle both PagedKVCache and model's cache.
	for i := 1; i < maxTokens; i++ {
		if nextToken == eosID {
			break
		}
		select {
		case <-ctx.Done():
			return "", ctx.Err()
		default:
		}
		input := createInputTensor([]int{nextToken})
		positions := createPositionTensor(len(ids)-1, 1)
		if modelCache != nil {
			// Use model's cache (recurrent state).
			before := modelCache.SeqLen()
			logits, err = s.eng.Forward(ctx, input, positions, modelCache)
			if err != nil {
				return "", err
			}
			if modelCache.SeqLen() == before {
				modelCache.Commit(1)
			}
		} else {
			before := cache.SeqLen()
			logits, err = s.eng.Forward(ctx, input, positions, cache)
			if err != nil {
				return "", err
			}
			if cache.SeqLen() == before {
				cache.Commit(1)
			}
		}
		recent := ids
		if len(recent) > 64 {
			recent = recent[len(recent)-64:]
		}
		nextToken, err = sampleNextToken(logits, 0, sampler, recent)
		if err != nil {
			return "", err
		}
		ids = append(ids, nextToken)
		if pagedCache != nil {
			pagedCache.AppendToken(nextToken)
		}
		sb.WriteString(s.tok.Decode([]int{nextToken}))
	}
	return sb.String(), nil
}

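// generateStream mirrors generate but invokes onPiece with each decoded fragment as soon
// as it is sampled, so the HTTP handler can forward it as an SSE chunk.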
func (s *Server) generateStream(ctx context.Context, requestID string, prompt string, maxTokens int, temperature float64, topP float64, topK int, onPiece func(string) error) (string, error) {
	scratchSet, releaseScratch, err := s.acquireScratchSet(ctx)
	if err != nil {
		return "", err
	}
	if scratchSet != nil {
		defer releaseScratch()
		ctx = compute.WithScratchSet(ctx, scratchSet)
		if len(s.scratchGPUs) > 0 {
			if sc := scratchSet.Scratch(s.scratchGPUs[0]); sc != nil {
				ctx = compute.WithScratch(ctx, sc)
			}
		}
	}
	ids := s.tok.Encode(prompt)
	if len(ids) == 0 {
		return "", fmt.Errorf("empty prompt after tokenization")
	}
	// Create cache - use model's CacheFactory if available, otherwise PagedKVCache.
	var cache kvcache.KVCacheInterface
	var pagedCache *kvcache.PagedKVCache
	var modelCache model.KVCache
	var cachedTokens int
	if s.cacheFactory != nil {
		// Model uses custom cache (e.g., KimiLinear with recurrent state).
		mc, err := s.cacheFactory.CreateCache()
		if err != nil {
			return "", fmt.Errorf("model cache creation failed: %w", err)
		}
		modelCache = mc
	} else {
		// Standard attention model - use PagedKVCache.
		if s.blockPool == nil {
			return "", fmt.Errorf("BlockPool not initialized")
		}
		modelCfg := s.eng.Model().Config()
		pagedCache = kvcache.NewPagedKVCache(s.blockPool, kvcache.PagedCacheConfig{
			NumLayers:  modelCfg.NumLayers,
			NumKVHeads: modelCfg.NumKVHeads,
			HeadDim:    modelCfg.HeadDim,
			BlockSize:  s.blockSize,
			MaxSeqLen:  s.maxSeqLen,
			Device:     s.blockPoolDevice,
			GPU:        s.blockPoolGPU,
		}, requestID)
		cachedTokens, err = pagedCache.AllocateForTokens(ids)
		if err != nil {
			pagedCache.Free()
			return "", fmt.Errorf("PagedKVCache alloc failed: %w", err)
		}
		defer pagedCache.Free()
		cache = pagedCache
		if cachedTokens > 0 {
			log.Printf("prefix_cache_hit request=%s cached_tokens=%d prompt_tokens=%d", requestID, cachedTokens, len(ids))
		}
	}
	sampler := sample.New(sample.Config{
		Temperature:       float32(temperature),
		TopK:              topK,
		TopP:              float32(topP),
		RepetitionPenalty: 1.1,
		Seed:              -1,
	})
	// Prefill - skip cached tokens if we have a prefix cache hit.
	prefillIDs := ids
	if cachedTokens > 0 && cachedTokens < len(ids) {
		prefillIDs = ids[cachedTokens:]
	}
	var logits tensor.Tensor
	if modelCache != nil {
		// Use model's cache (recurrent state) - no chunked prefill, process all at once.
		logits, err = s.prefillWithModelCache(ctx, modelCache, prefillIDs)
	} else {
		logits, err = s.prefill(ctx, cache, prefillIDs)
	}
	if err != nil {
		return "", err
	}
	var nextToken int
	chunk := s.prefillChunkSize
	if chunk <= 0 {
		chunk = 512
	}
	lastPartLen := len(prefillIDs) % chunk
	if lastPartLen == 0 {
		lastPartLen = min(chunk, len(prefillIDs))
	}
	rowIdx := lastPartLen - 1
	if modelCache != nil {
		// For recurrent models, we processed all tokens at once.
		rowIdx = len(prefillIDs) - 1
	}
	nextToken, err = sampleNextToken(logits, rowIdx, sampler, ids)
	if err != nil {
		return "", err
	}
	ids = append(ids, nextToken)
	if pagedCache != nil {
		pagedCache.AppendToken(nextToken)
	}
	var sb strings.Builder
	first := s.tok.Decode([]int{nextToken})
	sb.WriteString(first)
	if err := onPiece(first); err != nil {
		return sb.String(), err
	}
	eosID := s.tok.EosID()
	useBatcher := s.batcher != nil && pagedCache != nil
	if useBatcher {
		if _, ok := s.eng.Model().(model.BatchForwarder); !ok {
			useBatcher = false
		}
	}
	if useBatcher {
		seq := &engine.DecodeSequence{
			RequestID:      requestID,
			Ctx:            ctx,
			Cache:          cache,
			History:        ids,
			NextInputToken: nextToken,
			Remaining:      maxTokens - 1,
			EosID:          eosID,
			Sampler:        sampler,
		}
		events, err := s.batcher.RegisterDecode(seq)
		if err != nil {
			return sb.String(), err
		}
		for ev := range events {
			if ev.Err != nil {
				return sb.String(), ev.Err
			}
			if ev.Done {
				break
			}
			ids = append(ids, ev.Token)
			piece := s.tok.Decode([]int{ev.Token})
			sb.WriteString(piece)
			if err := onPiece(piece); err != nil {
				return sb.String(), err
			}
		}
		return sb.String(), nil
	}
	// Decode loop - handle both PagedKVCache and model's cache.
	for i := 1; i < maxTokens; i++ {
		if nextToken == eosID {
			break
		}
		select {
		case <-ctx.Done():
			return sb.String(), ctx.Err()
		default:
		}
		input := createInputTensor([]int{nextToken})
		positions := createPositionTensor(len(ids)-1, 1)
		if modelCache != nil {
			// Use model's cache (recurrent state).
			before := modelCache.SeqLen()
			logits, err = s.eng.Forward(ctx, input, positions, modelCache)
			if err != nil {
				return sb.String(), err
			}
			if modelCache.SeqLen() == before {
				modelCache.Commit(1)
			}
		} else {
			before := cache.SeqLen()
			logits, err = s.eng.Forward(ctx, input, positions, cache)
			if err != nil {
				return sb.String(), err
			}
			if cache.SeqLen() == before {
				cache.Commit(1)
			}
		}
		recent := ids
		if len(recent) > 64 {
			recent = recent[len(recent)-64:]
		}
		nextToken, err = sampleNextToken(logits, 0, sampler, recent)
		if err != nil {
			return sb.String(), err
		}
		ids = append(ids, nextToken)
		if pagedCache != nil {
			pagedCache.AppendToken(nextToken)
		}
		piece := s.tok.Decode([]int{nextToken})
		sb.WriteString(piece)
		if err := onPiece(piece); err != nil {
			return sb.String(), err
		}
	}
	return sb.String(), nil
}

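// createInputTensor and createPositionTensor pack token IDs and positions into 1-D CPU
// float32 tensors for the engine's Forward call.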
func createInputTensor(ids []int) tensor.Tensor {
	t := cpu.NewTensor(tensor.Shape{len(ids)}, nil)
	data := t.DataFloat32()
	for i, id := range ids {
		data[i] = float32(id)
	}
	return t
}

func createPositionTensor(start, count int) tensor.Tensor {
	t := cpu.NewTensor(tensor.Shape{count}, nil)
	data := t.DataFloat32()
	for i := 0; i < count; i++ {
		data[i] = float32(start + i)
	}
	return t
}

func getLogitsRowCPU(logits tensor.Tensor, row int) []float32 {
	if _, ok := logits.(*cpu.Tensor); !ok {
		return nil
	}
	data := logits.Data().(unsafe.Pointer)
	shape := logits.Shape()
	vocabSize := shape[1]
	slice := unsafe.Slice((*float32)(data), shape.NumElements())
	return slice[row*vocabSize : (row+1)*vocabSize]
}

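// sampleNextToken picks the next token from one row of the logits tensor. CPU logits are
// sampled directly; CUDA logits go through the device TopK kernel when k <= 64, otherwise
// the row is copied to the host and sampled there.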
func sampleNextToken(logits tensor.Tensor, row int, sampler *sample.Sampler, recentTokens []int) (int, error) {
	if logits == nil {
		return 0, fmt.Errorf("nil logits")
	}
	if sampler == nil {
		return 0, fmt.Errorf("nil sampler")
	}
	recent := recentTokens
	if len(recent) > 64 {
		recent = recent[len(recent)-64:]
	}
	if logitsCPU := getLogitsRowCPU(logits, row); logitsCPU != nil {
		return sampler.Sample(logitsCPU, recent), nil
	}
	gpuLogits, ok := logits.(*cuda.Tensor)
	if !ok {
		return 0, fmt.Errorf("unexpected logits type %T", logits)
	}
	shape := gpuLogits.Shape()
	if len(shape) != 2 {
		return 0, fmt.Errorf("expected 2D logits, got shape %v", shape)
	}
	vocabSize := shape[1]
	view, err := gpuLogits.ViewAt(tensor.Shape{vocabSize}, uintptr(row*vocabSize*4))
	if err != nil {
		return 0, err
	}
	cfg := sampler.Config()
	k := cfg.TopK
	if cfg.Temperature == 0 {
		k = 1
	}
	// CUDA TopK kernel supports k<=64; fall back to full D2H if disabled or too large.
	if k <= 0 || k > 64 {
		host := make([]float32, vocabSize)
		if err := view.CopyToHost(host); err != nil {
			return 0, err
		}
		return sampler.Sample(host, recent), nil
	}
	repPenalty := cfg.RepetitionPenalty
	if repPenalty <= 0 {
		repPenalty = 1.0
	}
	repIDs := make([]int32, len(recent))
	for i, t := range recent {
		repIDs[i] = int32(t)
	}
	allIDs, allScores, blocks, err := cuda.TopKLogitsF32(view.Data().(unsafe.Pointer), vocabSize, repIDs, repPenalty, k, gpuLogits.GPU())
	if err != nil {
		return 0, err
	}
	cands := make([]struct {
		id    int32
		score float32
	}, 0, blocks*k)
	for i := 0; i < blocks*k; i++ {
		if allIDs[i] < 0 {
			continue
		}
		cands = append(cands, struct {
			id    int32
			score float32
		}{id: allIDs[i], score: allScores[i]})
	}
	if len(cands) == 0 {
		// Defensive fallback.
		host := make([]float32, vocabSize)
		if err := view.CopyToHost(host); err != nil {
			return 0, err
		}
		return sampler.Sample(host, recent), nil
	}
	sort.Slice(cands, func(i, j int) bool { return cands[i].score > cands[j].score })
	if len(cands) > k {
		cands = cands[:k]
	}
	finalIDs := make([]int32, len(cands))
	finalScores := make([]float32, len(cands))
	for i := range cands {
		finalIDs[i] = cands[i].id
		finalScores[i] = cands[i].score
	}
	return sampler.SampleFromTopK(finalIDs, finalScores), nil
}