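- // Package openai exposes the inference engine over an OpenAI-compatible HTTP API
- // (/v1/chat/completions, /v1/completions, /v1/models), including SSE streaming.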
- package openai
- import (
- "context"
- "errors"
- "encoding/json"
- "fmt"
- "log"
- "math/rand"
- "net/http"
- "sort"
- "strings"
- "sync"
- "time"
- "unicode/utf8"
- "unsafe"
- "makarna/pkg/backend/cpu"
- "makarna/pkg/backend/cuda"
- "makarna/pkg/backend/device"
- "makarna/pkg/chat"
- "makarna/pkg/compute"
- "makarna/pkg/engine"
- "makarna/pkg/kvcache"
- "makarna/pkg/model"
- "makarna/pkg/sample"
- "makarna/pkg/tensor"
- "makarna/pkg/tokenizer"
- )
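- // ChatCompletionRequest mirrors the body of an OpenAI /v1/chat/completions request;
- // unset sampling fields fall back to server-side defaults.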
- type ChatCompletionRequest struct {
- Model string `json:"model"`
- Messages []ChatCompletionMessage `json:"messages"`
- Tools []any `json:"tools,omitempty"`
- Stream bool `json:"stream,omitempty"`
- SessionID string `json:"session_id,omitempty"`
- MaxTokens int `json:"max_tokens,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- PresencePenalty float64 `json:"presence_penalty,omitempty"`
- FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
- }
- type CompletionRequest struct {
- Model string `json:"model"`
- Prompt string `json:"prompt"`
- Stream bool `json:"stream,omitempty"`
- MaxTokens int `json:"max_tokens,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- }
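- // utf8StreamBuffer accumulates decoded bytes so streamed pieces are only emitted on
- // complete UTF-8 rune boundaries (token pieces can split multibyte characters).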
- type utf8StreamBuffer struct {
- buf []byte
- }
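- // Push appends s and returns the longest prefix that ends on a complete rune;
- // trailing bytes of an incomplete rune stay buffered for the next call.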
- func (b *utf8StreamBuffer) Push(s string) string {
- if s == "" {
- return ""
- }
- b.buf = append(b.buf, []byte(s)...)
- outLen := 0
- for outLen < len(b.buf) {
- _, size := utf8.DecodeRune(b.buf[outLen:])
- if size == 0 {
- break
- }
- if size == 1 {
- if !utf8.FullRune(b.buf[outLen:]) {
- break
- }
- }
- outLen += size
- }
- if outLen == 0 {
- return ""
- }
- out := string(b.buf[:outLen])
- b.buf = b.buf[outLen:]
- return out
- }
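- // Flush returns whatever remains in the buffer, or "" if the leftover bytes are not valid UTF-8.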
- func (b *utf8StreamBuffer) Flush() string {
- if len(b.buf) == 0 {
- return ""
- }
- if !utf8.Valid(b.buf) {
- b.buf = nil
- return ""
- }
- out := string(b.buf)
- b.buf = nil
- return out
- }
- type ChatCompletionMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
- Name string `json:"name,omitempty"`
- ToolCallID string `json:"tool_call_id,omitempty"`
- }
- type ChatCompletionResponse struct {
- ID string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Model string `json:"model"`
- Usage ChatCompletionUsage `json:"usage"`
- Choices []ChatCompletionChoice `json:"choices"`
- }
- type CompletionResponse struct {
- ID string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Model string `json:"model"`
- Choices []CompletionChoice `json:"choices"`
- Usage CompletionUsage `json:"usage"`
- }
- type CompletionChoice struct {
- Index int `json:"index"`
- Text string `json:"text"`
- FinishReason string `json:"finish_reason"`
- LogProbs any `json:"logprobs"`
- }
- type CompletionUsage struct {
- PromptTokens int `json:"prompt_tokens"`
- CompletionTokens int `json:"completion_tokens"`
- TotalTokens int `json:"total_tokens"`
- }
- type openAIErrorResponse struct {
- Error openAIError `json:"error"`
- }
- type openAIError struct {
- Message string `json:"message"`
- Type string `json:"type"`
- Param *string `json:"param,omitempty"`
- Code *string `json:"code,omitempty"`
- }
- type contextLengthExceededError struct {
- PromptTokens int
- MaxSeqLen int
- }
- func (e *contextLengthExceededError) Error() string {
- return fmt.Sprintf("context length exceeded: prompt_tokens=%d max_seq_len=%d", e.PromptTokens, e.MaxSeqLen)
- }
- type ChatCompletionUsage struct {
- PromptTokens int `json:"prompt_tokens"`
- CompletionTokens int `json:"completion_tokens"`
- TotalTokens int `json:"total_tokens"`
- }
- type ChatCompletionChunkResponse struct {
- ID string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Model string `json:"model"`
- Choices []ChatCompletionChunkChoice `json:"choices"`
- }
- type ChatCompletionChunkChoice struct {
- Index int `json:"index"`
- Delta ChatDelta `json:"delta"`
- FinishReason *string `json:"finish_reason"`
- }
- type ChatDelta struct {
- Role string `json:"role,omitempty"`
- Content string `json:"content,omitempty"`
- }
- type ChatCompletionChoice struct {
- Index int `json:"index"`
- Message ChatCompletionOutMsg `json:"message"`
- FinishReason string `json:"finish_reason"`
- }
- type ChatCompletionOutMsg struct {
- Role string `json:"role"`
- Content string `json:"content"`
- ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
- }
- type OpenAIToolCall struct {
- ID string `json:"id"`
- Type string `json:"type"`
- Function OpenAIFunctionCall `json:"function"`
- }
- type OpenAIFunctionCall struct {
- Name string `json:"name"`
- Arguments string `json:"arguments"`
- }
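- // Config holds server settings: listen address, context window, KV block sizing,
- // prefill chunking, and the concurrency used to size caches and scratch buffers.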
- type Config struct {
- Listen string
- MaxSeqLen int
- BlockSize int
- KVCacheCPU bool
- EnableThinking bool
- PrefillChunkSize int
- MaxConcurrent int
- }
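- // Server wires the engine, tokenizer, KV cache pool, and scratch buffers behind the HTTP handlers.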
- type Server struct {
- eng *engine.Engine
- batcher *engine.Batcher
- tok *tokenizer.Tokenizer
- arch string
- modelID string
- maxSeqLen int
- blockSize int
- prefillChunkSize int
- mu sync.Mutex
- // global block pool for prefix caching (used by standard attention models)
- blockPool *kvcache.BlockPool
- blockPoolDevice tensor.DeviceType
- blockPoolGPU int
- // cacheFactory is set if model implements model.CacheFactory (e.g., KimiLinear, Mamba)
- // When set, we use the model's cache instead of PagedKVCache
- cacheFactory model.CacheFactory
- scratchPool chan *compute.ScratchSet
- scratchGPUs []int
- scratchBytes int
- enableThinking bool
- stripTokens []string
- }
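- // Serve initializes the KV cache and scratch pools for eng, registers the HTTP handlers,
- // and blocks listening on cfg.Listen. A minimal usage sketch (how eng and tok are loaded
- // is outside this package; the arch string below is only illustrative):
- //
- //	log.Fatal(openai.Serve(eng, tok, "llama", openai.Config{Listen: ":8080", MaxSeqLen: 8192}))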
- func Serve(eng *engine.Engine, tok *tokenizer.Tokenizer, arch string, cfg Config) error {
- if cfg.Listen == "" {
- cfg.Listen = ":8080"
- }
- if cfg.MaxSeqLen <= 0 {
- cfg.MaxSeqLen = 8192
- }
- if cfg.BlockSize <= 0 {
- cfg.BlockSize = 32
- }
- if cfg.PrefillChunkSize <= 0 {
- cfg.PrefillChunkSize = 512
- }
- if cfg.MaxConcurrent <= 0 {
- cfg.MaxConcurrent = 4
- }
- s := &Server{
- eng: eng,
- batcher: engine.NewBatcher(eng),
- tok: tok,
- arch: arch,
- modelID: arch,
- maxSeqLen: cfg.MaxSeqLen,
- blockSize: cfg.BlockSize,
- enableThinking: cfg.EnableThinking,
- prefillChunkSize: cfg.PrefillChunkSize,
- }
- // Check if model implements CacheFactory (e.g., KimiLinear, Mamba)
- if cf, ok := eng.Model().(model.CacheFactory); ok {
- s.cacheFactory = cf
- log.Printf("kv_cache: using model's CacheFactory (type=%d)", cf.CacheType())
- }
- // Only start batcher for models using PagedKVCache (standard attention)
- // Recurrent models don't support batching yet
- if s.batcher != nil && s.cacheFactory == nil {
- s.batcher.Start()
- } else {
- s.batcher = nil
- }
- // Initialize global BlockPool for prefix caching (only for standard attention models)
- modelCfg := eng.Model().Config()
- var (
- layerPlacements []tensor.DevicePlacement
- anyCUDALayer bool
- )
- if eng.Dispatcher() != nil {
- layerPlacements = make([]tensor.DevicePlacement, modelCfg.NumLayers)
- for i := 0; i < modelCfg.NumLayers; i++ {
- p := eng.Dispatcher().LayerPlacement(i).Normalize()
- layerPlacements[i] = p
- if p.Type == tensor.CUDA && p.GPU >= 0 {
- anyCUDALayer = true
- }
- }
- }
- // Skip BlockPool for models using CacheFactory (recurrent state models)
- if s.cacheFactory == nil {
- blocksPerRequest := (cfg.MaxSeqLen + cfg.BlockSize - 1) / cfg.BlockSize
- blocksPerLayer := blocksPerRequest * cfg.MaxConcurrent
- if blocksPerLayer <= 0 {
- return fmt.Errorf("invalid KV pool sizing: max_seq_len=%d block_size=%d max_concurrent=%d", cfg.MaxSeqLen, cfg.BlockSize, cfg.MaxConcurrent)
- }
- // Support mixed CPU/GPU offload: allocate KV blocks per-layer on the same device as that layer.
- // This avoids the pathological "KV on CPU forces attention on CPU" slowdown when not all layers fit on GPU.
- if !cfg.KVCacheCPU && device.CUDAAvailable() && anyCUDALayer {
- blockPool, err := kvcache.NewBlockPool(kvcache.BlockPoolConfig{
- NumLayers: modelCfg.NumLayers,
- NumKVHeads: modelCfg.NumKVHeads,
- HeadDim: modelCfg.HeadDim,
- BlockSize: cfg.BlockSize,
- NumBlocks: blocksPerLayer,
- Device: tensor.CUDA, // default; LayerPlacements can still contain CPU layers
- GPU: 0,
- LayerPlacements: layerPlacements,
- Preallocate: true,
- })
- if err != nil {
- return fmt.Errorf("KV cache reserve failed (max_seq_len=%d max_concurrent=%d): %w", cfg.MaxSeqLen, cfg.MaxConcurrent, err)
- }
- s.blockPool = blockPool
- s.blockPoolDevice = tensor.CUDA
- s.blockPoolGPU = 0
- log.Printf("kv_cache: device=mixed block_size=%d blocks_per_layer=%d", cfg.BlockSize, blocksPerLayer)
- } else {
- blockPool, err := kvcache.NewBlockPool(kvcache.BlockPoolConfig{
- NumLayers: modelCfg.NumLayers,
- NumKVHeads: modelCfg.NumKVHeads,
- HeadDim: modelCfg.HeadDim,
- BlockSize: cfg.BlockSize,
- NumBlocks: blocksPerLayer,
- Device: tensor.CPU,
- GPU: 0,
- })
- if err != nil {
- return fmt.Errorf("KV cache init failed: %w", err)
- }
- s.blockPool = blockPool
- s.blockPoolDevice = tensor.CPU
- s.blockPoolGPU = 0
- log.Printf("kv_cache: device=cpu block_size=%d blocks_per_layer=%d", cfg.BlockSize, blocksPerLayer)
- }
- }
- // Allocate scratch sets up-front so prefill/decode doesn't trigger cudaMalloc churn.
- if device.CUDAAvailable() && eng != nil && eng.Dispatcher() != nil {
- gpus := collectDispatcherGPUs(eng.Dispatcher(), modelCfg.NumLayers)
- if len(gpus) > 0 {
- seqLen := cfg.PrefillChunkSize
- if seqLen <= 0 {
- seqLen = 512
- }
- if cfg.MaxSeqLen > 0 && seqLen > cfg.MaxSeqLen {
- seqLen = cfg.MaxSeqLen
- }
- scratchBytes := estimateScratchBytes(modelCfg, seqLen)
- poolSize := cfg.MaxConcurrent
- if poolSize <= 0 {
- poolSize = 1
- }
- const minScratchBytes = 8 << 20
- var scratchErr error
- for scratchBytes >= minScratchBytes {
- created := make([]*compute.ScratchSet, 0, poolSize)
- ok := true
- for i := 0; i < poolSize; i++ {
- ss, err := compute.NewScratchSet(gpus, scratchBytes)
- if err != nil {
- scratchErr = err
- ok = false
- break
- }
- created = append(created, ss)
- }
- if ok {
- s.scratchPool = make(chan *compute.ScratchSet, poolSize)
- for _, ss := range created {
- s.scratchPool <- ss
- }
- s.scratchGPUs = gpus
- s.scratchBytes = scratchBytes
- log.Printf("scratch: gpus=%v bytes=%d sets=%d", gpus, scratchBytes, poolSize)
- break
- }
- for _, prev := range created {
- prev.Free()
- }
- scratchBytes /= 2
- }
- if s.scratchPool == nil && scratchErr != nil {
- log.Printf("scratch disabled (reserve failed): %v", scratchErr)
- }
- }
- }
- // Warm up GPU weight caches so the first request doesn't spend seconds uploading weights.
- // Run asynchronously so the server can start listening immediately (important for OpenWebUI discovery).
- if s.scratchPool != nil && s.eng != nil {
- go func() {
- ss := <-s.scratchPool
- defer func() { s.scratchPool <- ss }()
- warmCtx := context.Background()
- warmCtx = compute.WithScratchSet(warmCtx, ss)
- if len(s.scratchGPUs) > 0 {
- if sc := ss.Scratch(s.scratchGPUs[0]); sc != nil {
- warmCtx = compute.WithScratch(warmCtx, sc)
- }
- }
- warmSeqLen := cfg.PrefillChunkSize
- if warmSeqLen <= 0 {
- warmSeqLen = 512
- }
- if cfg.MaxSeqLen > 0 && warmSeqLen > cfg.MaxSeqLen {
- warmSeqLen = cfg.MaxSeqLen
- }
- warmIDs := make([]int, warmSeqLen)
- _, err := s.eng.Forward(warmCtx, createInputTensor(warmIDs), createPositionTensor(0, warmSeqLen), nil)
- if err != nil {
- log.Printf("warmup forward failed: %v", err)
- return
- }
- log.Printf("warmup forward ok seq_len=%d", warmSeqLen)
- compute.LogWeightCacheSummary()
- }()
- }
- if tok != nil {
- for _, st := range tok.AddedTokenStrings() {
- if strings.Contains(strings.ToLower(st), "think") {
- continue
- }
- s.stripTokens = append(s.stripTokens, st)
- }
- }
- mux := http.NewServeMux()
- mux.HandleFunc("/v1/chat/completions", s.handleChatCompletions)
- mux.HandleFunc("/v1/completions", s.handleCompletions)
- mux.HandleFunc("/v1/models", s.handleModels)
- mux.HandleFunc("/", s.handleRoot)
- h := withCORS(mux)
- log.Printf("listening on %s (arch=%s, cuda=%v)", cfg.Listen, arch, device.CUDAAvailable())
- return http.ListenAndServe(cfg.Listen, h)
- }
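- // withCORS adds permissive CORS headers and answers OPTIONS preflight requests directly.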
- func withCORS(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Access-Control-Allow-Origin", "*")
- w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
- w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
- if r.Method == http.MethodOptions {
- w.WriteHeader(http.StatusNoContent)
- return
- }
- next.ServeHTTP(w, r)
- })
- }
- func (s *Server) handleModels(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodGet {
- http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
- return
- }
- resp := map[string]any{
- "object": "list",
- "data": []any{map[string]any{
- "id": s.modelID,
- "object": "model",
- "owned_by": "local",
- }},
- }
- writeJSON(w, resp)
- }
- func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request) {
- if r.URL.Path != "/" {
- http.NotFound(w, r)
- return
- }
- if r.Method != http.MethodGet {
- http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
- return
- }
- writeJSON(w, map[string]any{
- "status": "ok",
- "api": map[string]any{
- "chat_completions": "/v1/chat/completions",
- "completions": "/v1/completions",
- "models": "/v1/models",
- },
- })
- }
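- // handleCompletions serves the legacy /v1/completions endpoint (non-streaming only).
- // An illustrative request:
- //
- //	curl -s localhost:8080/v1/completions -d '{"prompt":"Hello","max_tokens":32}'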
- func (s *Server) handleCompletions(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodPost {
- http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
- return
- }
- start := time.Now()
- var req CompletionRequest
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- http.Error(w, "bad json", http.StatusBadRequest)
- return
- }
- if req.Model == "" {
- req.Model = s.modelID
- }
- if req.Prompt == "" {
- http.Error(w, "prompt required", http.StatusBadRequest)
- return
- }
- if req.Stream {
- http.Error(w, "stream not implemented", http.StatusNotImplemented)
- return
- }
- maxTokens := req.MaxTokens
- if maxTokens <= 0 {
- maxTokens = 128
- }
- temp := req.Temperature
- if temp == 0 {
- temp = 0.7
- }
- topP := req.TopP
- if topP == 0 {
- topP = 0.9
- }
- topK := req.TopK
- if topK == 0 {
- topK = 40
- }
- requestID := fmt.Sprintf("cmpl_%d", rand.Int63())
- log.Printf("/v1/completions id=%s stream=false remote=%s model=%s max_tokens=%d", requestID, r.RemoteAddr, req.Model, maxTokens)
- outText, genErr := s.generate(r.Context(), requestID, req.Prompt, maxTokens, temp, topP, topK)
- if genErr != nil {
- log.Printf("/v1/completions id=%s error=%v", requestID, genErr)
- http.Error(w, fmt.Sprintf("generate: %v", genErr), http.StatusInternalServerError)
- return
- }
- promptTokens := 0
- completionTokens := 0
- if s.tok != nil {
- promptTokens = len(s.tok.Encode(req.Prompt))
- completionTokens = len(s.tok.Encode(outText))
- }
- resp := CompletionResponse{
- ID: requestID,
- Object: "text_completion",
- Created: time.Now().Unix(),
- Model: req.Model,
- Choices: []CompletionChoice{{
- Index: 0,
- Text: outText,
- FinishReason: "stop",
- LogProbs: nil,
- }},
- Usage: CompletionUsage{
- PromptTokens: promptTokens,
- CompletionTokens: completionTokens,
- TotalTokens: promptTokens + completionTokens,
- },
- }
- log.Printf("/v1/completions id=%s done prompt_tokens=%d completion_tokens=%d ms=%d", requestID, promptTokens, completionTokens, time.Since(start).Milliseconds())
- writeJSON(w, resp)
- }
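- // handleChatCompletions serves /v1/chat/completions; with "stream": true it delegates
- // to the SSE handler. An illustrative request body:
- //
- //	{"messages":[{"role":"user","content":"Hi"}],"max_tokens":64,"stream":false}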
- func (s *Server) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodPost {
- http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
- return
- }
- start := time.Now()
- var req ChatCompletionRequest
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- http.Error(w, "bad json", http.StatusBadRequest)
- return
- }
- if req.Model == "" {
- req.Model = s.modelID
- }
- if len(req.Messages) == 0 {
- http.Error(w, "messages required", http.StatusBadRequest)
- return
- }
- if req.Stream {
- s.handleChatCompletionsStream(w, r, req)
- return
- }
- requestID := fmt.Sprintf("chatcmpl_%d", rand.Int63())
- msgs := make([]chat.Message, 0, len(req.Messages))
- for _, m := range req.Messages {
- role := strings.ToLower(m.Role)
- msgs = append(msgs, chat.Message{Role: role, Content: m.Content})
- }
- prompt, promptTokens, truncated, err := s.renderPromptWithinBudget(msgs, req.Tools)
- if err != nil {
- var cl *contextLengthExceededError
- if errors.As(err, &cl) {
- writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("This model's maximum context length is %d tokens. However, your messages resulted in %d tokens.", cl.MaxSeqLen, cl.PromptTokens), "context_length_exceeded", "messages")
- return
- }
- http.Error(w, fmt.Sprintf("render prompt: %v", err), http.StatusInternalServerError)
- return
- }
- if truncated {
- log.Printf("/v1/chat/completions id=%s prompt_truncated=true", requestID)
- }
- maxTokens := s.computeMaxTokens(req.MaxTokens, promptTokens)
- log.Printf("/v1/chat/completions id=%s stream=false remote=%s model=%s messages=%d max_tokens=%d prompt_tokens=%d", requestID, r.RemoteAddr, req.Model, len(req.Messages), maxTokens, promptTokens)
- temp := req.Temperature
- if temp == 0 {
- temp = 0.7
- }
- topP := req.TopP
- if topP == 0 {
- topP = 0.9
- }
- topK := req.TopK
- if topK == 0 {
- topK = 40
- }
- outText, genErr := s.generate(r.Context(), requestID, prompt, maxTokens, temp, topP, topK)
- if genErr != nil {
- log.Printf("/v1/chat/completions id=%s error=%v", requestID, genErr)
- http.Error(w, fmt.Sprintf("generate: %v", genErr), http.StatusInternalServerError)
- return
- }
- genTokens := 0
- if s.tok != nil {
- genTokens = len(s.tok.Encode(outText))
- }
- content := strings.TrimSpace(outText)
- content, calls, err := chat.ExtractToolCalls(content)
- if err != nil {
- http.Error(w, fmt.Sprintf("parse tool_calls: %v", err), http.StatusInternalServerError)
- return
- }
- content = strings.TrimSpace(content)
- content = sanitizeAssistantContent(content, s.stripTokens)
- content = strings.TrimSpace(content)
- var outCalls []OpenAIToolCall
- for i, c := range calls {
- outCalls = append(outCalls, OpenAIToolCall{
- ID: fmt.Sprintf("call_%d_%d", time.Now().UnixNano(), i),
- Type: "function",
- Function: OpenAIFunctionCall{
- Name: c.Name,
- Arguments: string(bytesOrEmptyObject(c.Arguments)),
- },
- })
- }
- finish := "stop"
- if len(outCalls) > 0 {
- finish = "tool_calls"
- }
- log.Printf("/v1/chat/completions id=%s done stream=false prompt_tokens=%d gen_tokens=%d tool_calls=%d ms=%d", requestID, promptTokens, genTokens, len(outCalls), time.Since(start).Milliseconds())
- resp := ChatCompletionResponse{
- ID: requestID,
- Object: "chat.completion",
- Created: time.Now().Unix(),
- Model: req.Model,
- Usage: ChatCompletionUsage{
- PromptTokens: promptTokens,
- CompletionTokens: genTokens,
- TotalTokens: promptTokens + genTokens,
- },
- Choices: []ChatCompletionChoice{{
- Index: 0,
- Message: ChatCompletionOutMsg{
- Role: "assistant",
- Content: content,
- ToolCalls: outCalls,
- },
- FinishReason: finish,
- }},
- }
- writeJSON(w, resp)
- }
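- // handleChatCompletionsStream emits chat.completion.chunk events over SSE: a role-only
- // delta first, then content deltas (UTF-8 buffered and sanitized), then a finish chunk
- // and a terminating "data: [DONE]" line.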
- func (s *Server) handleChatCompletionsStream(w http.ResponseWriter, r *http.Request, req ChatCompletionRequest) {
- flusher, ok := w.(http.Flusher)
- if !ok {
- http.Error(w, "streaming unsupported", http.StatusInternalServerError)
- return
- }
- start := time.Now()
- msgs := make([]chat.Message, 0, len(req.Messages))
- for _, m := range req.Messages {
- role := strings.ToLower(m.Role)
- msgs = append(msgs, chat.Message{Role: role, Content: m.Content})
- }
- id := fmt.Sprintf("chatcmpl_%d", rand.Int63())
- created := time.Now().Unix()
- prompt, promptTokens, truncated, err := s.renderPromptWithinBudget(msgs, req.Tools)
- if err != nil {
- var cl *contextLengthExceededError
- if errors.As(err, &cl) {
- writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("This model's maximum context length is %d tokens. However, your messages resulted in %d tokens.", cl.MaxSeqLen, cl.PromptTokens), "context_length_exceeded", "messages")
- return
- }
- http.Error(w, fmt.Sprintf("render prompt: %v", err), http.StatusInternalServerError)
- return
- }
- w.Header().Set("Content-Type", "text/event-stream")
- w.Header().Set("Cache-Control", "no-cache")
- w.Header().Set("Connection", "keep-alive")
- maxTokens := s.computeMaxTokens(req.MaxTokens, promptTokens)
- temp := req.Temperature
- if temp == 0 {
- temp = 0.7
- }
- topP := req.TopP
- if topP == 0 {
- topP = 0.9
- }
- topK := req.TopK
- if topK == 0 {
- topK = 40
- }
- if truncated {
- log.Printf("/v1/chat/completions id=%s prompt_truncated=true", id)
- }
- log.Printf("/v1/chat/completions id=%s stream=true remote=%s model=%s messages=%d max_tokens=%d prompt_tokens=%d", id, r.RemoteAddr, req.Model, len(req.Messages), maxTokens, promptTokens)
- if err := writeSSEJSON(w, ChatCompletionChunkResponse{
- ID: id,
- Object: "chat.completion.chunk",
- Created: created,
- Model: req.Model,
- Choices: []ChatCompletionChunkChoice{{
- Index: 0,
- Delta: ChatDelta{Role: "assistant"},
- }},
- }); err != nil {
- return
- }
- flusher.Flush()
- genTokens := 0
- utf8buf := &utf8StreamBuffer{}
- _, genErr := s.generateStream(r.Context(), id, prompt, maxTokens, temp, topP, topK, func(piece string) error {
- genTokens++
- // Decode can emit partial UTF-8 sequences across tokens.
- // Buffer first, then sanitize, otherwise sanitize may corrupt multibyte sequences.
- piece = utf8buf.Push(piece)
- if piece == "" {
- return nil
- }
- piece = sanitizeAssistantContent(piece, s.stripTokens)
- if err := writeSSEJSON(w, ChatCompletionChunkResponse{
- ID: id,
- Object: "chat.completion.chunk",
- Created: created,
- Model: req.Model,
- Choices: []ChatCompletionChunkChoice{{
- Index: 0,
- Delta: ChatDelta{Content: piece},
- }},
- }); err != nil {
- return err
- }
- flusher.Flush()
- return nil
- })
- if genErr != nil {
- log.Printf("/v1/chat/completions id=%s error=%v", id, genErr)
- if errors.Is(genErr, context.Canceled) || errors.Is(r.Context().Err(), context.Canceled) {
- finish := "cancelled"
- _ = writeSSEJSON(w, ChatCompletionChunkResponse{
- ID: id,
- Object: "chat.completion.chunk",
- Created: created,
- Model: req.Model,
- Choices: []ChatCompletionChunkChoice{{
- Index: 0,
- Delta: ChatDelta{},
- FinishReason: &finish,
- }},
- })
- flusher.Flush()
- _, _ = fmt.Fprint(w, "data: [DONE]\n\n")
- flusher.Flush()
- }
- return
- }
- if tail := utf8buf.Flush(); tail != "" {
- tail = sanitizeAssistantContent(tail, s.stripTokens)
- _ = writeSSEJSON(w, ChatCompletionChunkResponse{
- ID: id,
- Object: "chat.completion.chunk",
- Created: created,
- Model: req.Model,
- Choices: []ChatCompletionChunkChoice{{
- Index: 0,
- Delta: ChatDelta{Content: tail},
- }},
- })
- flusher.Flush()
- }
- finish := "stop"
- if err := writeSSEJSON(w, ChatCompletionChunkResponse{
- ID: id,
- Object: "chat.completion.chunk",
- Created: created,
- Model: req.Model,
- Choices: []ChatCompletionChunkChoice{{
- Index: 0,
- Delta: ChatDelta{},
- FinishReason: &finish,
- }},
- }); err != nil {
- return
- }
- flusher.Flush()
- _, _ = fmt.Fprint(w, "data: [DONE]\n\n")
- flusher.Flush()
- log.Printf("/v1/chat/completions id=%s done stream=true prompt_tokens=%d gen_tokens=%d ms=%d", id, promptTokens, genTokens, time.Since(start).Milliseconds())
- }
- func bytesOrEmptyObject(b []byte) []byte {
- if len(b) == 0 {
- return []byte("{}")
- }
- return b
- }
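- // sanitizeAssistantContent removes special/added tokenizer strings from model output.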
- func sanitizeAssistantContent(s string, stripTokens []string) string {
- if s == "" {
- return s
- }
- for _, tok := range stripTokens {
- if tok == "" {
- continue
- }
- s = strings.ReplaceAll(s, tok, "")
- }
- return s
- }
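- // computeMaxTokens clamps the requested completion budget to the remaining context
- // (maxSeqLen minus prompt tokens); a non-positive request means "use the full budget".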
- func (s *Server) computeMaxTokens(reqMaxTokens int, promptTokens int) int {
- budget := s.maxSeqLen - promptTokens
- if budget < 1 {
- budget = 1
- }
- if reqMaxTokens <= 0 {
- return budget
- }
- if reqMaxTokens > budget {
- return budget
- }
- return reqMaxTokens
- }
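- // renderPromptWithinBudget renders the chat template for the server's architecture and
- // counts prompt tokens, returning a contextLengthExceededError if they exceed maxSeqLen.
- // The returned bool is reserved for truncation and is currently always false.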
- func (s *Server) renderPromptWithinBudget(msgs []chat.Message, tools []any) (string, int, bool, error) {
- p, err := chat.RenderForArchitecture(s.arch, msgs, chat.Options{
- AddGenerationPrompt: true,
- EnableThinking: s.enableThinking,
- Tools: tools,
- })
- if err != nil {
- return "", 0, false, err
- }
- if s.tok == nil {
- return p, 0, false, nil
- }
- pt := len(s.tok.Encode(p))
- if pt > s.maxSeqLen {
- return "", pt, false, &contextLengthExceededError{PromptTokens: pt, MaxSeqLen: s.maxSeqLen}
- }
- return p, pt, false, nil
- }
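- // writeOpenAIError writes an OpenAI-style error envelope with the given status, message, code, and param.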
- func writeOpenAIError(w http.ResponseWriter, status int, message string, code string, param string) {
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(status)
- typeStr := "invalid_request_error"
- paramCopy := param
- codeCopy := code
- _ = json.NewEncoder(w).Encode(openAIErrorResponse{
- Error: openAIError{
- Message: message,
- Type: typeStr,
- Param: &paramCopy,
- Code: &codeCopy,
- },
- })
- }
- func writeJSON(w http.ResponseWriter, v any) {
- w.Header().Set("Content-Type", "application/json")
- enc := json.NewEncoder(w)
- enc.SetEscapeHTML(false)
- _ = enc.Encode(v)
- }
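- // writeSSEJSON marshals v and writes it as a single "data: ..." server-sent event.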
- func writeSSEJSON(w http.ResponseWriter, v any) error {
- b, err := json.Marshal(v)
- if err != nil {
- return err
- }
- _, err = fmt.Fprintf(w, "data: %s\n\n", b)
- return err
- }
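- // collectDispatcherGPUs returns the sorted, de-duplicated set of CUDA GPUs that host at least one layer.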
- func collectDispatcherGPUs(d *device.DeviceDispatcher, numLayers int) []int {
- if d == nil || numLayers <= 0 {
- return nil
- }
- seen := make(map[int]struct{})
- out := make([]int, 0, 4)
- for i := 0; i < numLayers; i++ {
- p := d.LayerPlacement(i).Normalize()
- if p.Type != tensor.CUDA || p.GPU < 0 {
- continue
- }
- if _, ok := seen[p.GPU]; ok {
- continue
- }
- seen[p.GPU] = struct{}{}
- out = append(out, p.GPU)
- }
- sort.Ints(out)
- return out
- }
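- // estimateScratchBytes sizes per-GPU scratch buffers from the model dims and prefill chunk
- // length: float32 activations for attention (q/k/v, attention output, projection) plus the
- // MLP intermediates, with slack, rounded up to a 16 MiB multiple.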
- func estimateScratchBytes(cfg *model.Config, seqLen int) int {
- if cfg == nil || seqLen <= 0 {
- return compute.DefaultScratchBytes
- }
- hiddenSize := cfg.HiddenSize
- numHeads := cfg.NumHeads
- numKVHeads := cfg.NumKVHeads
- if numKVHeads == 0 {
- numKVHeads = numHeads
- }
- headDim := cfg.HeadDim
- if headDim == 0 && numHeads > 0 {
- headDim = hiddenSize / numHeads
- }
- intermediate := cfg.Intermediate
- if intermediate == 0 {
- intermediate = hiddenSize * 4
- }
- qDim := numHeads * headDim
- kvDim := numKVHeads * headDim
- add := func(rows, cols int) int64 {
- if rows <= 0 || cols <= 0 {
- return 0
- }
- return int64(rows) * int64(cols) * 4
- }
- var bytes int64
- // Attention scratch.
- bytes += add(seqLen, qDim) // qOut
- bytes += add(seqLen, kvDim) // kOut
- bytes += add(seqLen, kvDim) // vOut
- bytes += add(seqLen, qDim) // attnOut
- bytes += add(seqLen, hiddenSize) // attnProj
- // MLP scratch.
- bytes += add(seqLen, intermediate) // gate
- bytes += add(seqLen, intermediate) // up
- bytes += add(seqLen, intermediate) // act
- bytes += add(seqLen, hiddenSize) // mlpOut
- // Extra space for ptr tables / int32 slices / alignment slack.
- bytes += 8 << 20
- // Round up to a 16 MiB multiple.
- const align = int64(16 << 20)
- if rem := bytes % align; rem != 0 {
- bytes += align - rem
- }
- if bytes < int64(compute.DefaultScratchBytes) {
- bytes = int64(compute.DefaultScratchBytes)
- }
- maxInt := int64(int(^uint(0) >> 1))
- if bytes > maxInt {
- bytes = maxInt
- }
- return int(bytes)
- }
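- // acquireScratchSet borrows a scratch set from the pool, blocking until one is free or
- // ctx is done; the returned release func resets it and returns it to the pool.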
- func (s *Server) acquireScratchSet(ctx context.Context) (*compute.ScratchSet, func(), error) {
- if s == nil || s.scratchPool == nil {
- return nil, func() {}, nil
- }
- if ctx == nil {
- ctx = context.Background()
- }
- select {
- case ss := <-s.scratchPool:
- if ss != nil {
- ss.Reset()
- }
- release := func() {
- if ss != nil {
- ss.Reset()
- }
- s.scratchPool <- ss
- }
- return ss, release, nil
- case <-ctx.Done():
- return nil, nil, ctx.Err()
- }
- }
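- // prefill runs the prompt through the model in prefillChunkSize pieces against a paged
- // KV cache, committing each chunk unless the model already advanced the cache itself.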
- func (s *Server) prefill(ctx context.Context, cache kvcache.KVCacheInterface, ids []int) (tensor.Tensor, error) {
- if len(ids) == 0 {
- return nil, fmt.Errorf("no prompt tokens to prefill")
- }
- chunk := s.prefillChunkSize
- if chunk <= 0 {
- chunk = 512
- }
- var logits tensor.Tensor
- for start := 0; start < len(ids); start += chunk {
- end := start + chunk
- if end > len(ids) {
- end = len(ids)
- }
- part := ids[start:end]
- input := createInputTensor(part)
- positions := createPositionTensor(cache.SeqLen(), len(part))
- before := cache.SeqLen()
- out, err := s.eng.Forward(ctx, input, positions, cache)
- if err != nil {
- return nil, err
- }
- logits = out
- // Some model implementations advance the cache internally.
- // If the cache didn't advance, commit here.
- if cache.SeqLen() == before {
- cache.Commit(len(part))
- }
- }
- return logits, nil
- }
- // prefillWithModelCache processes the whole prompt in a single forward pass using the model's custom cache
- // (for recurrent models). Unlike prefill, which splits the prompt into chunks, recurrent models such as
- // KimiLinear carry state across tokens, so the prompt is fed in one pass to keep that state consistent.
- func (s *Server) prefillWithModelCache(ctx context.Context, cache model.KVCache, ids []int) (tensor.Tensor, error) {
- if len(ids) == 0 {
- return nil, fmt.Errorf("no prompt tokens to prefill")
- }
- input := createInputTensor(ids)
- positions := createPositionTensor(cache.SeqLen(), len(ids))
- before := cache.SeqLen()
- logits, err := s.eng.Forward(ctx, input, positions, cache)
- if err != nil {
- return nil, err
- }
- // Commit if cache didn't advance internally
- if cache.SeqLen() == before {
- cache.Commit(len(ids))
- }
- return logits, nil
- }
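- // generate runs a non-streaming completion: tokenize, prefill (reusing the prefix cache when
- // available), then decode token by token via the batcher or a local loop until EOS or maxTokens.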
- func (s *Server) generate(ctx context.Context, requestID string, prompt string, maxTokens int, temperature float64, topP float64, topK int) (string, error) {
- scratchSet, releaseScratch, err := s.acquireScratchSet(ctx)
- if err != nil {
- return "", err
- }
- if scratchSet != nil {
- defer releaseScratch()
- ctx = compute.WithScratchSet(ctx, scratchSet)
- if len(s.scratchGPUs) > 0 {
- if sc := scratchSet.Scratch(s.scratchGPUs[0]); sc != nil {
- ctx = compute.WithScratch(ctx, sc)
- }
- }
- }
- ids := s.tok.Encode(prompt)
- if len(ids) == 0 {
- return "", fmt.Errorf("empty prompt after tokenization")
- }
- // Create cache - use model's CacheFactory if available, otherwise PagedKVCache
- var cache kvcache.KVCacheInterface
- var pagedCache *kvcache.PagedKVCache
- var modelCache model.KVCache
- var cachedTokens int
- if s.cacheFactory != nil {
- // Model uses custom cache (e.g., KimiLinear with recurrent state)
- mc, err := s.cacheFactory.CreateCache()
- if err != nil {
- return "", fmt.Errorf("model cache creation failed: %w", err)
- }
- modelCache = mc
- // modelCache implements model.KVCache but not kvcache.KVCacheInterface
- // We'll use it directly with the model's Forward function
- } else {
- // Standard attention model - use PagedKVCache
- if s.blockPool == nil {
- return "", fmt.Errorf("BlockPool not initialized")
- }
- modelCfg := s.eng.Model().Config()
- pagedCache = kvcache.NewPagedKVCache(s.blockPool, kvcache.PagedCacheConfig{
- NumLayers: modelCfg.NumLayers,
- NumKVHeads: modelCfg.NumKVHeads,
- HeadDim: modelCfg.HeadDim,
- BlockSize: s.blockSize,
- MaxSeqLen: s.maxSeqLen,
- Device: s.blockPoolDevice,
- GPU: s.blockPoolGPU,
- }, requestID)
- cachedTokens, err = pagedCache.AllocateForTokens(ids)
- if err != nil {
- pagedCache.Free()
- return "", fmt.Errorf("PagedKVCache alloc failed: %w", err)
- }
- defer pagedCache.Free()
- cache = pagedCache
- if cachedTokens > 0 {
- log.Printf("prefix_cache_hit request=%s cached_tokens=%d prompt_tokens=%d", requestID, cachedTokens, len(ids))
- }
- }
- sampler := sample.New(sample.Config{
- Temperature: float32(temperature),
- TopK: topK,
- TopP: float32(topP),
- RepetitionPenalty: 1.1,
- Seed: -1,
- })
- // Prefill - skip cached tokens if we have a prefix cache hit
- prefillIDs := ids
- if cachedTokens > 0 && cachedTokens < len(ids) {
- prefillIDs = ids[cachedTokens:]
- }
- var logits tensor.Tensor
- if modelCache != nil {
- // Use model's cache (recurrent state) - no chunked prefill, process all at once
- logits, err = s.prefillWithModelCache(ctx, modelCache, prefillIDs)
- } else {
- logits, err = s.prefill(ctx, cache, prefillIDs)
- }
- if err != nil {
- return "", err
- }
- var nextToken int
- chunk := s.prefillChunkSize
- if chunk <= 0 {
- chunk = 512
- }
- lastPartLen := len(prefillIDs) % chunk
- if lastPartLen == 0 {
- lastPartLen = min(chunk, len(prefillIDs))
- }
- rowIdx := lastPartLen - 1
- if modelCache != nil {
- // For recurrent models, we processed all tokens at once
- rowIdx = len(prefillIDs) - 1
- }
- nextToken, err = sampleNextToken(logits, rowIdx, sampler, ids)
- if err != nil {
- return "", err
- }
- ids = append(ids, nextToken)
- if pagedCache != nil {
- pagedCache.AppendToken(nextToken)
- }
- var sb strings.Builder
- sb.WriteString(s.tok.Decode([]int{nextToken}))
- eosID := s.tok.EosID()
- useBatcher := s.batcher != nil && pagedCache != nil
- if useBatcher {
- if _, ok := s.eng.Model().(model.BatchForwarder); !ok {
- useBatcher = false
- }
- }
- if useBatcher {
- seq := &engine.DecodeSequence{
- RequestID: requestID,
- Ctx: ctx,
- Cache: cache,
- History: ids,
- NextInputToken: nextToken,
- Remaining: maxTokens - 1,
- EosID: eosID,
- Sampler: sampler,
- }
- events, err := s.batcher.RegisterDecode(seq)
- if err != nil {
- return "", err
- }
- for ev := range events {
- if ev.Err != nil {
- return "", ev.Err
- }
- if ev.Done {
- break
- }
- ids = append(ids, ev.Token)
- sb.WriteString(s.tok.Decode([]int{ev.Token}))
- }
- return sb.String(), nil
- }
- // Decode loop - handle both PagedKVCache and model's cache
- for i := 1; i < maxTokens; i++ {
- if nextToken == eosID {
- break
- }
- select {
- case <-ctx.Done():
- return "", ctx.Err()
- default:
- }
- input := createInputTensor([]int{nextToken})
- positions := createPositionTensor(len(ids)-1, 1)
- if modelCache != nil {
- // Use model's cache (recurrent state)
- before := modelCache.SeqLen()
- logits, err = s.eng.Forward(ctx, input, positions, modelCache)
- if err != nil {
- return "", err
- }
- if modelCache.SeqLen() == before {
- modelCache.Commit(1)
- }
- } else {
- before := cache.SeqLen()
- logits, err = s.eng.Forward(ctx, input, positions, cache)
- if err != nil {
- return "", err
- }
- if cache.SeqLen() == before {
- cache.Commit(1)
- }
- }
- recent := ids
- if len(recent) > 64 {
- recent = recent[len(recent)-64:]
- }
- nextToken, err = sampleNextToken(logits, 0, sampler, recent)
- if err != nil {
- return "", err
- }
- ids = append(ids, nextToken)
- if pagedCache != nil {
- pagedCache.AppendToken(nextToken)
- }
- sb.WriteString(s.tok.Decode([]int{nextToken}))
- }
- return sb.String(), nil
- }
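- // generateStream mirrors generate but invokes onPiece with each decoded token's text so
- // the HTTP handler can forward it as an SSE delta.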
- func (s *Server) generateStream(ctx context.Context, requestID string, prompt string, maxTokens int, temperature float64, topP float64, topK int, onPiece func(string) error) (string, error) {
- scratchSet, releaseScratch, err := s.acquireScratchSet(ctx)
- if err != nil {
- return "", err
- }
- if scratchSet != nil {
- defer releaseScratch()
- ctx = compute.WithScratchSet(ctx, scratchSet)
- if len(s.scratchGPUs) > 0 {
- if sc := scratchSet.Scratch(s.scratchGPUs[0]); sc != nil {
- ctx = compute.WithScratch(ctx, sc)
- }
- }
- }
- ids := s.tok.Encode(prompt)
- if len(ids) == 0 {
- return "", fmt.Errorf("empty prompt after tokenization")
- }
- // Create cache - use model's CacheFactory if available, otherwise PagedKVCache
- var cache kvcache.KVCacheInterface
- var pagedCache *kvcache.PagedKVCache
- var modelCache model.KVCache
- var cachedTokens int
- if s.cacheFactory != nil {
- // Model uses custom cache (e.g., KimiLinear with recurrent state)
- mc, err := s.cacheFactory.CreateCache()
- if err != nil {
- return "", fmt.Errorf("model cache creation failed: %w", err)
- }
- modelCache = mc
- } else {
- // Standard attention model - use PagedKVCache
- if s.blockPool == nil {
- return "", fmt.Errorf("BlockPool not initialized")
- }
- modelCfg := s.eng.Model().Config()
- pagedCache = kvcache.NewPagedKVCache(s.blockPool, kvcache.PagedCacheConfig{
- NumLayers: modelCfg.NumLayers,
- NumKVHeads: modelCfg.NumKVHeads,
- HeadDim: modelCfg.HeadDim,
- BlockSize: s.blockSize,
- MaxSeqLen: s.maxSeqLen,
- Device: s.blockPoolDevice,
- GPU: s.blockPoolGPU,
- }, requestID)
- cachedTokens, err = pagedCache.AllocateForTokens(ids)
- if err != nil {
- pagedCache.Free()
- return "", fmt.Errorf("PagedKVCache alloc failed: %w", err)
- }
- defer pagedCache.Free()
- cache = pagedCache
- if cachedTokens > 0 {
- log.Printf("prefix_cache_hit request=%s cached_tokens=%d prompt_tokens=%d", requestID, cachedTokens, len(ids))
- }
- }
- sampler := sample.New(sample.Config{
- Temperature: float32(temperature),
- TopK: topK,
- TopP: float32(topP),
- RepetitionPenalty: 1.1,
- Seed: -1,
- })
- // Prefill - skip cached tokens if we have a prefix cache hit
- prefillIDs := ids
- if cachedTokens > 0 && cachedTokens < len(ids) {
- prefillIDs = ids[cachedTokens:]
- }
- var logits tensor.Tensor
- if modelCache != nil {
- // Use model's cache (recurrent state) - no chunked prefill, process all at once
- logits, err = s.prefillWithModelCache(ctx, modelCache, prefillIDs)
- } else {
- logits, err = s.prefill(ctx, cache, prefillIDs)
- }
- if err != nil {
- return "", err
- }
- var nextToken int
- chunk := s.prefillChunkSize
- if chunk <= 0 {
- chunk = 512
- }
- lastPartLen := len(prefillIDs) % chunk
- if lastPartLen == 0 {
- lastPartLen = min(chunk, len(prefillIDs))
- }
- rowIdx := lastPartLen - 1
- if modelCache != nil {
- // For recurrent models, we processed all tokens at once
- rowIdx = len(prefillIDs) - 1
- }
- nextToken, err = sampleNextToken(logits, rowIdx, sampler, ids)
- if err != nil {
- return "", err
- }
- ids = append(ids, nextToken)
- if pagedCache != nil {
- pagedCache.AppendToken(nextToken)
- }
- var sb strings.Builder
- first := s.tok.Decode([]int{nextToken})
- sb.WriteString(first)
- if err := onPiece(first); err != nil {
- return sb.String(), err
- }
- eosID := s.tok.EosID()
- useBatcher := s.batcher != nil && pagedCache != nil
- if useBatcher {
- if _, ok := s.eng.Model().(model.BatchForwarder); !ok {
- useBatcher = false
- }
- }
- if useBatcher {
- seq := &engine.DecodeSequence{
- RequestID: requestID,
- Ctx: ctx,
- Cache: cache,
- History: ids,
- NextInputToken: nextToken,
- Remaining: maxTokens - 1,
- EosID: eosID,
- Sampler: sampler,
- }
- events, err := s.batcher.RegisterDecode(seq)
- if err != nil {
- return sb.String(), err
- }
- for ev := range events {
- if ev.Err != nil {
- return sb.String(), ev.Err
- }
- if ev.Done {
- break
- }
- ids = append(ids, ev.Token)
- piece := s.tok.Decode([]int{ev.Token})
- sb.WriteString(piece)
- if err := onPiece(piece); err != nil {
- return sb.String(), err
- }
- }
- return sb.String(), nil
- }
- // Decode loop - handle both PagedKVCache and model's cache
- for i := 1; i < maxTokens; i++ {
- if nextToken == eosID {
- break
- }
- select {
- case <-ctx.Done():
- return sb.String(), ctx.Err()
- default:
- }
- input := createInputTensor([]int{nextToken})
- positions := createPositionTensor(len(ids)-1, 1)
- if modelCache != nil {
- // Use model's cache (recurrent state)
- before := modelCache.SeqLen()
- logits, err = s.eng.Forward(ctx, input, positions, modelCache)
- if err != nil {
- return sb.String(), err
- }
- if modelCache.SeqLen() == before {
- modelCache.Commit(1)
- }
- } else {
- before := cache.SeqLen()
- logits, err = s.eng.Forward(ctx, input, positions, cache)
- if err != nil {
- return sb.String(), err
- }
- if cache.SeqLen() == before {
- cache.Commit(1)
- }
- }
- recent := ids
- if len(recent) > 64 {
- recent = recent[len(recent)-64:]
- }
- nextToken, err = sampleNextToken(logits, 0, sampler, recent)
- if err != nil {
- return sb.String(), err
- }
- ids = append(ids, nextToken)
- if pagedCache != nil {
- pagedCache.AppendToken(nextToken)
- }
- piece := s.tok.Decode([]int{nextToken})
- sb.WriteString(piece)
- if err := onPiece(piece); err != nil {
- return sb.String(), err
- }
- }
- return sb.String(), nil
- }
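- // createInputTensor packs token IDs into a 1-D CPU float32 tensor, the input format the engine expects.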
- func createInputTensor(ids []int) tensor.Tensor {
- t := cpu.NewTensor(tensor.Shape{len(ids)}, nil)
- data := t.DataFloat32()
- for i, id := range ids {
- data[i] = float32(id)
- }
- return t
- }
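- // createPositionTensor builds the 1-D position tensor [start, start+count) as float32.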
- func createPositionTensor(start, count int) tensor.Tensor {
- t := cpu.NewTensor(tensor.Shape{count}, nil)
- data := t.DataFloat32()
- for i := 0; i < count; i++ {
- data[i] = float32(start + i)
- }
- return t
- }
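- // getLogitsRowCPU returns the given row of a CPU logits tensor without copying, or nil if the tensor is not on CPU.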
- func getLogitsRowCPU(logits tensor.Tensor, row int) []float32 {
- if _, ok := logits.(*cpu.Tensor); !ok {
- return nil
- }
- data := logits.Data().(unsafe.Pointer)
- shape := logits.Shape()
- vocabSize := shape[1]
- slice := unsafe.Slice((*float32)(data), shape.NumElements())
- return slice[row*vocabSize : (row+1)*vocabSize]
- }
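- // sampleNextToken samples from one row of the logits. CPU logits are sampled directly;
- // CUDA logits use the on-device top-k kernel when k <= 64 and fall back to a full
- // device-to-host copy otherwise.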
- func sampleNextToken(logits tensor.Tensor, row int, sampler *sample.Sampler, recentTokens []int) (int, error) {
- if logits == nil {
- return 0, fmt.Errorf("nil logits")
- }
- if sampler == nil {
- return 0, fmt.Errorf("nil sampler")
- }
- recent := recentTokens
- if len(recent) > 64 {
- recent = recent[len(recent)-64:]
- }
- if logitsCPU := getLogitsRowCPU(logits, row); logitsCPU != nil {
- return sampler.Sample(logitsCPU, recent), nil
- }
- gpuLogits, ok := logits.(*cuda.Tensor)
- if !ok {
- return 0, fmt.Errorf("unexpected logits type %T", logits)
- }
- shape := gpuLogits.Shape()
- if len(shape) != 2 {
- return 0, fmt.Errorf("expected 2D logits, got shape %v", shape)
- }
- vocabSize := shape[1]
- view, err := gpuLogits.ViewAt(tensor.Shape{vocabSize}, uintptr(row*vocabSize*4))
- if err != nil {
- return 0, err
- }
- cfg := sampler.Config()
- k := cfg.TopK
- if cfg.Temperature == 0 {
- k = 1
- }
- // CUDA TopK kernel supports k<=64; fall back to full D2H if disabled or too large.
- if k <= 0 || k > 64 {
- host := make([]float32, vocabSize)
- if err := view.CopyToHost(host); err != nil {
- return 0, err
- }
- return sampler.Sample(host, recent), nil
- }
- repPenalty := cfg.RepetitionPenalty
- if repPenalty <= 0 {
- repPenalty = 1.0
- }
- repIDs := make([]int32, len(recent))
- for i, t := range recent {
- repIDs[i] = int32(t)
- }
- allIDs, allScores, blocks, err := cuda.TopKLogitsF32(view.Data().(unsafe.Pointer), vocabSize, repIDs, repPenalty, k, gpuLogits.GPU())
- if err != nil {
- return 0, err
- }
- cands := make([]struct {
- id int32
- score float32
- }, 0, blocks*k)
- for i := 0; i < blocks*k; i++ {
- if allIDs[i] < 0 {
- continue
- }
- cands = append(cands, struct {
- id int32
- score float32
- }{id: allIDs[i], score: allScores[i]})
- }
- if len(cands) == 0 {
- // Defensive fallback.
- host := make([]float32, vocabSize)
- if err := view.CopyToHost(host); err != nil {
- return 0, err
- }
- return sampler.Sample(host, recent), nil
- }
- sort.Slice(cands, func(i, j int) bool { return cands[i].score > cands[j].score })
- if len(cands) > k {
- cands = cands[:k]
- }
- finalIDs := make([]int32, len(cands))
- finalScores := make([]float32, len(cands))
- for i := range cands {
- finalIDs[i] = cands[i].id
- finalScores[i] = cands[i].score
- }
- return sampler.SampleFromTopK(finalIDs, finalScores), nil
- }