package openai

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"sort"
	"strings"
	"sync"
	"time"
	"unicode/utf8"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/backend/device"
	"makarna/pkg/chat"
	"makarna/pkg/compute"
	"makarna/pkg/engine"
	"makarna/pkg/kvcache"
	"makarna/pkg/model"
	"makarna/pkg/sample"
	"makarna/pkg/tensor"
	"makarna/pkg/tokenizer"
)

type ChatCompletionRequest struct {
	Model            string                  `json:"model"`
	Messages         []ChatCompletionMessage `json:"messages"`
	Tools            []any                   `json:"tools,omitempty"`
	Stream           bool                    `json:"stream,omitempty"`
	SessionID        string                  `json:"session_id,omitempty"`
	MaxTokens        int                     `json:"max_tokens,omitempty"`
	Temperature      float64                 `json:"temperature,omitempty"`
	TopP             float64                 `json:"top_p,omitempty"`
	TopK             int                     `json:"top_k,omitempty"`
	PresencePenalty  float64                 `json:"presence_penalty,omitempty"`
	FrequencyPenalty float64                 `json:"frequency_penalty,omitempty"`
}

type CompletionRequest struct {
	Model       string  `json:"model"`
	Prompt      string  `json:"prompt"`
	Stream      bool    `json:"stream,omitempty"`
	MaxTokens   int     `json:"max_tokens,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
	TopP        float64 `json:"top_p,omitempty"`
	TopK        int     `json:"top_k,omitempty"`
}

type utf8StreamBuffer struct {
	buf []byte
}

func (b *utf8StreamBuffer) Push(s string) string {
	if s == "" {
		return ""
	}
	b.buf = append(b.buf, []byte(s)...)
	outLen := 0
	for outLen < len(b.buf) {
		_, size := utf8.DecodeRune(b.buf[outLen:])
		if size == 0 {
			break
		}
		if size == 1 {
			if !utf8.FullRune(b.buf[outLen:]) {
				break
			}
		}
		outLen += size
	}
	if outLen == 0 {
		return ""
	}
	out := string(b.buf[:outLen])
	b.buf = b.buf[outLen:]
	return out
}

func (b *utf8StreamBuffer) Flush() string {
	if len(b.buf) == 0 {
		return ""
	}
	if !utf8.Valid(b.buf) {
		b.buf = nil
		return ""
	}
	out := string(b.buf)
	b.buf = nil
	return out
}

type ChatCompletionMessage struct {
	Role       string `json:"role"`
	Content    string `json:"content"`
	Name       string `json:"name,omitempty"`
	ToolCallID string `json:"tool_call_id,omitempty"`
}

type ChatCompletionResponse struct {
	ID      string                 `json:"id"`
	Object  string                 `json:"object"`
	Created int64                  `json:"created"`
	Model   string                 `json:"model"`
	Usage   ChatCompletionUsage    `json:"usage"`
	Choices []ChatCompletionChoice `json:"choices"`
}

type CompletionResponse struct {
	ID      string             `json:"id"`
	Object  string             `json:"object"`
	Created int64              `json:"created"`
	Model   string             `json:"model"`
	Choices []CompletionChoice `json:"choices"`
	Usage   CompletionUsage    `json:"usage"`
}

type CompletionChoice struct {
	Index        int    `json:"index"`
	Text         string `json:"text"`
	FinishReason string `json:"finish_reason"`
	LogProbs     any    `json:"logprobs"`
}

type CompletionUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type openAIErrorResponse struct {
	Error openAIError `json:"error"`
}

type openAIError struct {
	Message string  `json:"message"`
	Type    string  `json:"type"`
	Param   *string `json:"param,omitempty"`
	Code    *string `json:"code,omitempty"`
}

type contextLengthExceededError struct {
	PromptTokens int
	MaxSeqLen    int
}

func (e *contextLengthExceededError) Error() string {
	return fmt.Sprintf("context length exceeded: prompt_tokens=%d max_seq_len=%d", e.PromptTokens, e.MaxSeqLen)
}

type ChatCompletionUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
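// ChatCompletionChunkResponse is a single server-sent event in a streaming
// chat completion ("object": "chat.completion.chunk"). The first chunk carries
// the assistant role, later chunks carry content deltas, and the final chunk
// sets finish_reason before the "data: [DONE]" sentinel.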
type ChatCompletionChunkResponse struct {
	ID      string                      `json:"id"`
	Object  string                      `json:"object"`
	Created int64                       `json:"created"`
	Model   string                      `json:"model"`
	Choices []ChatCompletionChunkChoice `json:"choices"`
}

type ChatCompletionChunkChoice struct {
	Index        int       `json:"index"`
	Delta        ChatDelta `json:"delta"`
	FinishReason *string   `json:"finish_reason"`
}

type ChatDelta struct {
	Role    string `json:"role,omitempty"`
	Content string `json:"content,omitempty"`
}

type ChatCompletionChoice struct {
	Index        int                  `json:"index"`
	Message      ChatCompletionOutMsg `json:"message"`
	FinishReason string               `json:"finish_reason"`
}

type ChatCompletionOutMsg struct {
	Role      string           `json:"role"`
	Content   string           `json:"content"`
	ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
}

type OpenAIToolCall struct {
	ID       string             `json:"id"`
	Type     string             `json:"type"`
	Function OpenAIFunctionCall `json:"function"`
}

type OpenAIFunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}

type Config struct {
	Listen           string
	MaxSeqLen        int
	BlockSize        int
	KVCacheCPU       bool
	EnableThinking   bool
	PrefillChunkSize int
	MaxConcurrent    int
}

type Server struct {
	eng              *engine.Engine
	batcher          *engine.Batcher
	tok              *tokenizer.Tokenizer
	arch             string
	modelID          string
	maxSeqLen        int
	blockSize        int
	prefillChunkSize int
	mu               sync.Mutex

	// global block pool for prefix caching (used by standard attention models)
	blockPool       *kvcache.BlockPool
	blockPoolDevice tensor.DeviceType
	blockPoolGPU    int

	// cacheFactory is set if model implements model.CacheFactory (e.g., KimiLinear, Mamba)
	// When set, we use the model's cache instead of PagedKVCache
	cacheFactory model.CacheFactory

	scratchPool  chan *compute.ScratchSet
	scratchGPUs  []int
	scratchBytes int

	enableThinking bool
	stripTokens    []string
}

func Serve(eng *engine.Engine, tok *tokenizer.Tokenizer, arch string, cfg Config) error {
	if cfg.Listen == "" {
		cfg.Listen = ":8080"
	}
	if cfg.MaxSeqLen <= 0 {
		cfg.MaxSeqLen = 8192
	}
	if cfg.BlockSize <= 0 {
		cfg.BlockSize = 32
	}
	if cfg.PrefillChunkSize <= 0 {
		cfg.PrefillChunkSize = 512
	}
	if cfg.MaxConcurrent <= 0 {
		cfg.MaxConcurrent = 4
	}
	s := &Server{
		eng:              eng,
		batcher:          engine.NewBatcher(eng),
		tok:              tok,
		arch:             arch,
		modelID:          arch,
		maxSeqLen:        cfg.MaxSeqLen,
		blockSize:        cfg.BlockSize,
		enableThinking:   cfg.EnableThinking,
		prefillChunkSize: cfg.PrefillChunkSize,
	}
	// Check if model implements CacheFactory (e.g., KimiLinear, Mamba)
	if cf, ok := eng.Model().(model.CacheFactory); ok {
		s.cacheFactory = cf
		log.Printf("kv_cache: using model's CacheFactory (type=%d)", cf.CacheType())
	}
	// Only start batcher for models using PagedKVCache (standard attention)
	// Recurrent models don't support batching yet
	if s.batcher != nil && s.cacheFactory == nil {
		s.batcher.Start()
	} else {
		s.batcher = nil
	}
	// Initialize global BlockPool for prefix caching (only for standard attention models)
	modelCfg := eng.Model().Config()
	var (
		layerPlacements []tensor.DevicePlacement
		anyCUDALayer    bool
	)
	if eng.Dispatcher() != nil {
		layerPlacements = make([]tensor.DevicePlacement, modelCfg.NumLayers)
		for i := 0; i < modelCfg.NumLayers; i++ {
			p := eng.Dispatcher().LayerPlacement(i).Normalize()
			layerPlacements[i] = p
			if p.Type == tensor.CUDA && p.GPU >= 0 {
				anyCUDALayer = true
			}
		}
	}
	// Skip BlockPool for models using CacheFactory (recurrent state models)
	if s.cacheFactory == nil {
		blocksPerRequest := (cfg.MaxSeqLen + cfg.BlockSize - 1) / cfg.BlockSize
		blocksPerLayer := blocksPerRequest * cfg.MaxConcurrent
		if blocksPerLayer <= 0 {
			return fmt.Errorf("invalid KV pool sizing: max_seq_len=%d block_size=%d max_concurrent=%d", cfg.MaxSeqLen, cfg.BlockSize, cfg.MaxConcurrent)
		}
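		// Example: with the defaults max_seq_len=8192, block_size=32 and
		// max_concurrent=4 this reserves ceil(8192/32)*4 = 1024 blocks per layer.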
		// Support mixed CPU/GPU offload: allocate KV blocks per-layer on the same device as that layer.
		// This avoids the pathological "KV on CPU forces attention on CPU" slowdown when not all layers fit on GPU.
		if !cfg.KVCacheCPU && device.CUDAAvailable() && anyCUDALayer {
			blockPool, err := kvcache.NewBlockPool(kvcache.BlockPoolConfig{
				NumLayers:       modelCfg.NumLayers,
				NumKVHeads:      modelCfg.NumKVHeads,
				HeadDim:         modelCfg.HeadDim,
				BlockSize:       cfg.BlockSize,
				NumBlocks:       blocksPerLayer,
				Device:          tensor.CUDA, // default; LayerPlacements can still contain CPU layers
				GPU:             0,
				LayerPlacements: layerPlacements,
				Preallocate:     true,
			})
			if err != nil {
				return fmt.Errorf("KV cache reserve failed (max_seq_len=%d max_concurrent=%d): %w", cfg.MaxSeqLen, cfg.MaxConcurrent, err)
			}
			s.blockPool = blockPool
			s.blockPoolDevice = tensor.CUDA
			s.blockPoolGPU = 0
			log.Printf("kv_cache: device=mixed block_size=%d blocks_per_layer=%d", cfg.BlockSize, blocksPerLayer)
		} else {
			blockPool, err := kvcache.NewBlockPool(kvcache.BlockPoolConfig{
				NumLayers:  modelCfg.NumLayers,
				NumKVHeads: modelCfg.NumKVHeads,
				HeadDim:    modelCfg.HeadDim,
				BlockSize:  cfg.BlockSize,
				NumBlocks:  blocksPerLayer,
				Device:     tensor.CPU,
				GPU:        0,
			})
			if err != nil {
				return fmt.Errorf("KV cache init failed: %w", err)
			}
			s.blockPool = blockPool
			s.blockPoolDevice = tensor.CPU
			s.blockPoolGPU = 0
			log.Printf("kv_cache: device=cpu block_size=%d blocks_per_layer=%d", cfg.BlockSize, blocksPerLayer)
		}
	}
	// Allocate scratch sets up-front so prefill/decode doesn't trigger cudaMalloc churn.
	if device.CUDAAvailable() && eng != nil && eng.Dispatcher() != nil {
		gpus := collectDispatcherGPUs(eng.Dispatcher(), modelCfg.NumLayers)
		if len(gpus) > 0 {
			seqLen := cfg.PrefillChunkSize
			if seqLen <= 0 {
				seqLen = 512
			}
			if cfg.MaxSeqLen > 0 && seqLen > cfg.MaxSeqLen {
				seqLen = cfg.MaxSeqLen
			}
			scratchBytes := estimateScratchBytes(modelCfg, seqLen)
			poolSize := cfg.MaxConcurrent
			if poolSize <= 0 {
				poolSize = 1
			}
			const minScratchBytes = 8 << 20
			var scratchErr error
			for scratchBytes >= minScratchBytes {
				created := make([]*compute.ScratchSet, 0, poolSize)
				ok := true
				for i := 0; i < poolSize; i++ {
					ss, err := compute.NewScratchSet(gpus, scratchBytes)
					if err != nil {
						scratchErr = err
						ok = false
						break
					}
					created = append(created, ss)
				}
				if ok {
					s.scratchPool = make(chan *compute.ScratchSet, poolSize)
					for _, ss := range created {
						s.scratchPool <- ss
					}
					s.scratchGPUs = gpus
					s.scratchBytes = scratchBytes
					log.Printf("scratch: gpus=%v bytes=%d sets=%d", gpus, scratchBytes, poolSize)
					break
				}
				for _, prev := range created {
					prev.Free()
				}
				scratchBytes /= 2
			}
			if s.scratchPool == nil && scratchErr != nil {
				log.Printf("scratch disabled (reserve failed): %v", scratchErr)
			}
		}
	}
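	// One ScratchSet per concurrent request is reserved above; acquireScratchSet
	// hands them out and blocks while all of them are in use, which effectively
	// caps in-flight generation at cfg.MaxConcurrent on the GPU path.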
	// Warm up GPU weight caches so the first request doesn't spend seconds uploading weights.
	// Run asynchronously so the server can start listening immediately (important for OpenWebUI discovery).
	if s.scratchPool != nil && s.eng != nil {
		go func() {
			ss := <-s.scratchPool
			defer func() { s.scratchPool <- ss }()
			warmCtx := context.Background()
			warmCtx = compute.WithScratchSet(warmCtx, ss)
			if len(s.scratchGPUs) > 0 {
				if sc := ss.Scratch(s.scratchGPUs[0]); sc != nil {
					warmCtx = compute.WithScratch(warmCtx, sc)
				}
			}
			warmSeqLen := cfg.PrefillChunkSize
			if warmSeqLen <= 0 {
				warmSeqLen = 512
			}
			if cfg.MaxSeqLen > 0 && warmSeqLen > cfg.MaxSeqLen {
				warmSeqLen = cfg.MaxSeqLen
			}
			warmIDs := make([]int, warmSeqLen)
			_, err := s.eng.Forward(warmCtx, createInputTensor(warmIDs), createPositionTensor(0, warmSeqLen), nil)
			if err != nil {
				log.Printf("warmup forward failed: %v", err)
				return
			}
			log.Printf("warmup forward ok seq_len=%d", warmSeqLen)
			compute.LogWeightCacheSummary()
		}()
	}
	if tok != nil {
		for _, st := range tok.AddedTokenStrings() {
			if strings.Contains(strings.ToLower(st), "think") {
				continue
			}
			s.stripTokens = append(s.stripTokens, st)
		}
	}
	mux := http.NewServeMux()
	mux.HandleFunc("/v1/chat/completions", s.handleChatCompletions)
	mux.HandleFunc("/v1/completions", s.handleCompletions)
	mux.HandleFunc("/v1/models", s.handleModels)
	mux.HandleFunc("/", s.handleRoot)
	h := withCORS(mux)
	log.Printf("listening on %s (arch=%s, cuda=%v)", cfg.Listen, arch, device.CUDAAvailable())
	return http.ListenAndServe(cfg.Listen, h)
}

func withCORS(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
		w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}

func (s *Server) handleModels(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	resp := map[string]any{
		"object": "list",
		"data": []any{map[string]any{
			"id":       s.modelID,
			"object":   "model",
			"owned_by": "local",
		}},
	}
	writeJSON(w, resp)
}

func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request) {
	if r.URL.Path != "/" {
		http.NotFound(w, r)
		return
	}
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	writeJSON(w, map[string]any{
		"status": "ok",
		"api": map[string]any{
			"chat_completions": "/v1/chat/completions",
			"completions":      "/v1/completions",
			"models":           "/v1/models",
		},
	})
}

func (s *Server) handleCompletions(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	start := time.Now()
	var req CompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "bad json", http.StatusBadRequest)
		return
	}
	if req.Model == "" {
		req.Model = s.modelID
	}
	if req.Prompt == "" {
		http.Error(w, "prompt required", http.StatusBadRequest)
		return
	}
	if req.Stream {
		http.Error(w, "stream not implemented", http.StatusNotImplemented)
		return
	}
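	// Fall back to the same sampling defaults as the chat endpoint when the
	// client omits them: temperature 0.7, top_p 0.9, top_k 40; max_tokens
	// defaults to 128 for this endpoint.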
	maxTokens := req.MaxTokens
	if maxTokens <= 0 {
		maxTokens = 128
	}
	temp := req.Temperature
	if temp == 0 {
		temp = 0.7
	}
	topP := req.TopP
	if topP == 0 {
		topP = 0.9
	}
	topK := req.TopK
	if topK == 0 {
		topK = 40
	}
	requestID := fmt.Sprintf("cmpl_%d", rand.Int63())
	log.Printf("/v1/completions id=%s stream=false remote=%s model=%s max_tokens=%d", requestID, r.RemoteAddr, req.Model, maxTokens)
	outText, genErr := s.generate(r.Context(), requestID, req.Prompt, maxTokens, temp, topP, topK)
	if genErr != nil {
		log.Printf("/v1/completions id=%s error=%v", requestID, genErr)
		http.Error(w, fmt.Sprintf("generate: %v", genErr), http.StatusInternalServerError)
		return
	}
	promptTokens := 0
	completionTokens := 0
	if s.tok != nil {
		promptTokens = len(s.tok.Encode(req.Prompt))
		completionTokens = len(s.tok.Encode(outText))
	}
	resp := CompletionResponse{
		ID:      requestID,
		Object:  "text_completion",
		Created: time.Now().Unix(),
		Model:   req.Model,
		Choices: []CompletionChoice{{
			Index:        0,
			Text:         outText,
			FinishReason: "stop",
			LogProbs:     nil,
		}},
		Usage: CompletionUsage{
			PromptTokens:     promptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      promptTokens + completionTokens,
		},
	}
	log.Printf("/v1/completions id=%s done prompt_tokens=%d completion_tokens=%d ms=%d", requestID, promptTokens, completionTokens, time.Since(start).Milliseconds())
	writeJSON(w, resp)
}

func (s *Server) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	start := time.Now()
	var req ChatCompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "bad json", http.StatusBadRequest)
		return
	}
	if req.Model == "" {
		req.Model = s.modelID
	}
	if len(req.Messages) == 0 {
		http.Error(w, "messages required", http.StatusBadRequest)
		return
	}
	if req.Stream {
		s.handleChatCompletionsStream(w, r, req)
		return
	}
	requestID := fmt.Sprintf("chatcmpl_%d", rand.Int63())
	msgs := make([]chat.Message, 0, len(req.Messages))
	for _, m := range req.Messages {
		role := strings.ToLower(m.Role)
		msgs = append(msgs, chat.Message{Role: role, Content: m.Content})
	}
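	// If the rendered prompt alone exceeds max_seq_len, reply with an
	// OpenAI-style invalid_request_error so clients can surface it, e.g.
	//
	//	{"error":{"message":"This model's maximum context length is 8192 tokens. ...",
	//	  "type":"invalid_request_error","param":"messages","code":"context_length_exceeded"}}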
	prompt, promptTokens, truncated, err := s.renderPromptWithinBudget(msgs, req.Tools)
	if err != nil {
		var cl *contextLengthExceededError
		if errors.As(err, &cl) {
			writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("This model's maximum context length is %d tokens. However, your messages resulted in %d tokens.", cl.MaxSeqLen, cl.PromptTokens), "context_length_exceeded", "messages")
			return
		}
		http.Error(w, fmt.Sprintf("render prompt: %v", err), http.StatusInternalServerError)
		return
	}
	if truncated {
		log.Printf("/v1/chat/completions id=%s prompt_truncated=true", requestID)
	}
	maxTokens := s.computeMaxTokens(req.MaxTokens, promptTokens)
	log.Printf("/v1/chat/completions id=%s stream=false remote=%s model=%s messages=%d max_tokens=%d prompt_tokens=%d", requestID, r.RemoteAddr, req.Model, len(req.Messages), maxTokens, promptTokens)
	temp := req.Temperature
	if temp == 0 {
		temp = 0.7
	}
	topP := req.TopP
	if topP == 0 {
		topP = 0.9
	}
	topK := req.TopK
	if topK == 0 {
		topK = 40
	}
	outText, genErr := s.generate(r.Context(), requestID, prompt, maxTokens, temp, topP, topK)
	if genErr != nil {
		log.Printf("/v1/chat/completions id=%s error=%v", requestID, genErr)
		http.Error(w, fmt.Sprintf("generate: %v", genErr), http.StatusInternalServerError)
		return
	}
	genTokens := 0
	if s.tok != nil {
		genTokens = len(s.tok.Encode(outText))
	}
	content := strings.TrimSpace(outText)
	content, calls, err := chat.ExtractToolCalls(content)
	if err != nil {
		http.Error(w, fmt.Sprintf("parse tool_calls: %v", err), http.StatusInternalServerError)
		return
	}
	content = strings.TrimSpace(content)
	content = sanitizeAssistantContent(content, s.stripTokens)
	content = strings.TrimSpace(content)
	var outCalls []OpenAIToolCall
	for i, c := range calls {
		outCalls = append(outCalls, OpenAIToolCall{
			ID:   fmt.Sprintf("call_%d_%d", time.Now().UnixNano(), i),
			Type: "function",
			Function: OpenAIFunctionCall{
				Name:      c.Name,
				Arguments: string(bytesOrEmptyObject(c.Arguments)),
			},
		})
	}
	finish := "stop"
	if len(outCalls) > 0 {
		finish = "tool_calls"
	}
	log.Printf("/v1/chat/completions id=%s done stream=false prompt_tokens=%d gen_tokens=%d tool_calls=%d ms=%d", requestID, promptTokens, genTokens, len(outCalls), time.Since(start).Milliseconds())
	resp := ChatCompletionResponse{
		ID:      requestID,
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   req.Model,
		Usage: ChatCompletionUsage{
			PromptTokens:     promptTokens,
			CompletionTokens: genTokens,
			TotalTokens:      promptTokens + genTokens,
		},
		Choices: []ChatCompletionChoice{{
			Index: 0,
			Message: ChatCompletionOutMsg{
				Role:      "assistant",
				Content:   content,
				ToolCalls: outCalls,
			},
			FinishReason: finish,
		}},
	}
	writeJSON(w, resp)
}

func (s *Server) handleChatCompletionsStream(w http.ResponseWriter, r *http.Request, req ChatCompletionRequest) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}
	start := time.Now()
	msg := make([]chat.Message, 0, len(req.Messages))
	for _, m := range req.Messages {
		role := strings.ToLower(m.Role)
		msg = append(msg, chat.Message{Role: role, Content: m.Content})
	}
	id := fmt.Sprintf("chatcmpl_%d", rand.Int63())
	created := time.Now().Unix()
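	// The response body is standard SSE: each chunk is written as a single
	// "data: <json>" line followed by a blank line, and the stream ends with a
	// literal "data: [DONE]" sentinel, e.g.
	//
	//	data: {"object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}
	//
	//	data: [DONE]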
	prompt, promptTokens, truncated, err := s.renderPromptWithinBudget(msg, req.Tools)
	if err != nil {
		var cl *contextLengthExceededError
		if errors.As(err, &cl) {
			writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("This model's maximum context length is %d tokens. However, your messages resulted in %d tokens.", cl.MaxSeqLen, cl.PromptTokens), "context_length_exceeded", "messages")
			return
		}
		http.Error(w, fmt.Sprintf("render prompt: %v", err), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	maxTokens := s.computeMaxTokens(req.MaxTokens, promptTokens)
	temp := req.Temperature
	if temp == 0 {
		temp = 0.7
	}
	topP := req.TopP
	if topP == 0 {
		topP = 0.9
	}
	topK := req.TopK
	if topK == 0 {
		topK = 40
	}
	if truncated {
		log.Printf("/v1/chat/completions id=%s prompt_truncated=true", id)
	}
	log.Printf("/v1/chat/completions id=%s stream=true remote=%s model=%s messages=%d max_tokens=%d prompt_tokens=%d", id, r.RemoteAddr, req.Model, len(req.Messages), maxTokens, promptTokens)
	if err := writeSSEJSON(w, ChatCompletionChunkResponse{
		ID:      id,
		Object:  "chat.completion.chunk",
		Created: created,
		Model:   req.Model,
		Choices: []ChatCompletionChunkChoice{{
			Index: 0,
			Delta: ChatDelta{Role: "assistant"},
		}},
	}); err != nil {
		return
	}
	flusher.Flush()
	genTokens := 0
	utf8buf := &utf8StreamBuffer{}
	_, genErr := s.generateStream(r.Context(), id, prompt, maxTokens, temp, topP, topK, func(piece string) error {
		genTokens++
		// Decode can emit partial UTF-8 sequences across tokens.
		// Buffer first, then sanitize, otherwise sanitize may corrupt multibyte sequences.
		piece = utf8buf.Push(piece)
		if piece == "" {
			return nil
		}
		piece = sanitizeAssistantContent(piece, s.stripTokens)
		if err := writeSSEJSON(w, ChatCompletionChunkResponse{
			ID:      id,
			Object:  "chat.completion.chunk",
			Created: created,
			Model:   req.Model,
			Choices: []ChatCompletionChunkChoice{{
				Index: 0,
				Delta: ChatDelta{Content: piece},
			}},
		}); err != nil {
			return err
		}
		flusher.Flush()
		return nil
	})
	if genErr != nil {
		log.Printf("/v1/chat/completions id=%s error=%v", id, genErr)
		if errors.Is(genErr, context.Canceled) || errors.Is(r.Context().Err(), context.Canceled) {
			finish := "cancelled"
			_ = writeSSEJSON(w, ChatCompletionChunkResponse{
				ID:      id,
				Object:  "chat.completion.chunk",
				Created: created,
				Model:   req.Model,
				Choices: []ChatCompletionChunkChoice{{
					Index:        0,
					Delta:        ChatDelta{},
					FinishReason: &finish,
				}},
			})
			flusher.Flush()
			_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
			flusher.Flush()
		}
		return
	}
	if tail := utf8buf.Flush(); tail != "" {
		tail = sanitizeAssistantContent(tail, s.stripTokens)
		_ = writeSSEJSON(w, ChatCompletionChunkResponse{
			ID:      id,
			Object:  "chat.completion.chunk",
			Created: created,
			Model:   req.Model,
			Choices: []ChatCompletionChunkChoice{{
				Index: 0,
				Delta: ChatDelta{Content: tail},
			}},
		})
		flusher.Flush()
	}
	finish := "stop"
	if err := writeSSEJSON(w, ChatCompletionChunkResponse{
		ID:      id,
		Object:  "chat.completion.chunk",
		Created: created,
		Model:   req.Model,
		Choices: []ChatCompletionChunkChoice{{
			Index:        0,
			Delta:        ChatDelta{},
			FinishReason: &finish,
		}},
	}); err != nil {
		return
	}
	flusher.Flush()
	_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
	flusher.Flush()
	log.Printf("/v1/chat/completions id=%s done stream=true prompt_tokens=%d gen_tokens=%d ms=%d", id, promptTokens, genTokens, time.Since(start).Milliseconds())
}

func bytesOrEmptyObject(b []byte) []byte {
	if len(b) == 0 {
		return []byte("{}")
	}
	return b
}

func sanitizeAssistantContent(s string, stripTokens []string) string {
	if s == "" {
		return s
	}
	for _, tok := range stripTokens {
		if tok == "" {
			continue
		}
		s = strings.ReplaceAll(s, tok, "")
	}
	return s
}
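// computeMaxTokens clamps the requested completion length to what still fits
// in the context window. For example, with max_seq_len=8192 and an 8000-token
// prompt the budget is 192, so max_tokens=512 is clamped to 192; if the client
// omits max_tokens the whole remaining budget is used.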
func (s *Server) computeMaxTokens(reqMaxTokens int, promptTokens int) int {
	budget := s.maxSeqLen - promptTokens
	if budget < 1 {
		budget = 1
	}
	if reqMaxTokens <= 0 {
		return budget
	}
	if reqMaxTokens > budget {
		return budget
	}
	return reqMaxTokens
}

func (s *Server) renderPromptWithinBudget(msgs []chat.Message, tools []any) (string, int, bool, error) {
	p, err := chat.RenderForArchitecture(s.arch, msgs, chat.Options{
		AddGenerationPrompt: true,
		EnableThinking:      s.enableThinking,
		Tools:               tools,
	})
	if err != nil {
		return "", 0, false, err
	}
	if s.tok == nil {
		return p, 0, false, nil
	}
	pt := len(s.tok.Encode(p))
	if pt > s.maxSeqLen {
		return "", pt, false, &contextLengthExceededError{PromptTokens: pt, MaxSeqLen: s.maxSeqLen}
	}
	return p, pt, false, nil
}

func writeOpenAIError(w http.ResponseWriter, status int, message string, code string, param string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	typeStr := "invalid_request_error"
	paramCopy := param
	codeCopy := code
	_ = json.NewEncoder(w).Encode(openAIErrorResponse{
		Error: openAIError{
			Message: message,
			Type:    typeStr,
			Param:   &paramCopy,
			Code:    &codeCopy,
		},
	})
}

func writeJSON(w http.ResponseWriter, v any) {
	w.Header().Set("Content-Type", "application/json")
	enc := json.NewEncoder(w)
	enc.SetEscapeHTML(false)
	_ = enc.Encode(v)
}

func writeSSEJSON(w http.ResponseWriter, v any) error {
	b, err := json.Marshal(v)
	if err != nil {
		return err
	}
	_, err = fmt.Fprintf(w, "data: %s\n\n", b)
	return err
}

func collectDispatcherGPUs(d *device.DeviceDispatcher, numLayers int) []int {
	if d == nil || numLayers <= 0 {
		return nil
	}
	seen := make(map[int]struct{})
	out := make([]int, 0, 4)
	for i := 0; i < numLayers; i++ {
		p := d.LayerPlacement(i).Normalize()
		if p.Type != tensor.CUDA || p.GPU < 0 {
			continue
		}
		if _, ok := seen[p.GPU]; ok {
			continue
		}
		seen[p.GPU] = struct{}{}
		out = append(out, p.GPU)
	}
	sort.Ints(out)
	return out
}
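// estimateScratchBytes sizes the per-GPU scratch buffer for one prefill chunk:
// FP32 activations for the attention projections (q/k/v, attention output and
// its projection) plus the MLP intermediates, a little slack for pointer
// tables, rounded up to a 16 MiB multiple and floored at
// compute.DefaultScratchBytes.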
func estimateScratchBytes(cfg *model.Config, seqLen int) int {
	if cfg == nil || seqLen <= 0 {
		return compute.DefaultScratchBytes
	}
	hiddenSize := cfg.HiddenSize
	numHeads := cfg.NumHeads
	numKVHeads := cfg.NumKVHeads
	if numKVHeads == 0 {
		numKVHeads = numHeads
	}
	headDim := cfg.HeadDim
	if headDim == 0 && numHeads > 0 {
		headDim = hiddenSize / numHeads
	}
	intermediate := cfg.Intermediate
	if intermediate == 0 {
		intermediate = hiddenSize * 4
	}
	qDim := numHeads * headDim
	kvDim := numKVHeads * headDim
	add := func(rows, cols int) int64 {
		if rows <= 0 || cols <= 0 {
			return 0
		}
		return int64(rows) * int64(cols) * 4
	}
	var bytes int64
	// Attention scratch.
	bytes += add(seqLen, qDim)        // qOut
	bytes += add(seqLen, kvDim)       // kOut
	bytes += add(seqLen, kvDim)       // vOut
	bytes += add(seqLen, qDim)        // attnOut
	bytes += add(seqLen, hiddenSize)  // attnProj
	// MLP scratch.
	bytes += add(seqLen, intermediate) // gate
	bytes += add(seqLen, intermediate) // up
	bytes += add(seqLen, intermediate) // act
	bytes += add(seqLen, hiddenSize)   // mlpOut
	// Extra space for ptr tables / int32 slices / alignment slack.
	bytes += 8 << 20
	// Round up to a multiple of 16 MiB.
	const align = int64(16 << 20)
	if rem := bytes % align; rem != 0 {
		bytes += align - rem
	}
	if bytes < int64(compute.DefaultScratchBytes) {
		bytes = int64(compute.DefaultScratchBytes)
	}
	maxInt := int64(int(^uint(0) >> 1))
	if bytes > maxInt {
		bytes = maxInt
	}
	return int(bytes)
}

func (s *Server) acquireScratchSet(ctx context.Context) (*compute.ScratchSet, func(), error) {
	if s == nil || s.scratchPool == nil {
		return nil, func() {}, nil
	}
	if ctx == nil {
		ctx = context.Background()
	}
	select {
	case ss := <-s.scratchPool:
		if ss != nil {
			ss.Reset()
		}
		release := func() {
			if ss != nil {
				ss.Reset()
			}
			s.scratchPool <- ss
		}
		return ss, release, nil
	case <-ctx.Done():
		return nil, nil, ctx.Err()
	}
}

func (s *Server) prefill(ctx context.Context, cache kvcache.KVCacheInterface, ids []int) (tensor.Tensor, error) {
	if len(ids) == 0 {
		return nil, fmt.Errorf("no prompt tokens to prefill")
	}
	chunk := s.prefillChunkSize
	if chunk <= 0 {
		chunk = 512
	}
	var logits tensor.Tensor
	for start := 0; start < len(ids); start += chunk {
		end := start + chunk
		if end > len(ids) {
			end = len(ids)
		}
		part := ids[start:end]
		input := createInputTensor(part)
		positions := createPositionTensor(cache.SeqLen(), len(part))
		before := cache.SeqLen()
		out, err := s.eng.Forward(ctx, input, positions, cache)
		if err != nil {
			return nil, err
		}
		logits = out
		// Some model implementations advance the cache internally.
		// If the cache didn't advance, commit here.
		if cache.SeqLen() == before {
			cache.Commit(len(part))
		}
	}
	return logits, nil
}

// prefillWithModelCache processes all tokens at once using the model's custom cache (for recurrent models).
// Unlike prefill which chunks tokens, recurrent models like KimiLinear need to process sequentially
// to properly update their recurrent state.
func (s *Server) prefillWithModelCache(ctx context.Context, cache model.KVCache, ids []int) (tensor.Tensor, error) {
	if len(ids) == 0 {
		return nil, fmt.Errorf("no prompt tokens to prefill")
	}
	input := createInputTensor(ids)
	positions := createPositionTensor(cache.SeqLen(), len(ids))
	before := cache.SeqLen()
	logits, err := s.eng.Forward(ctx, input, positions, cache)
	if err != nil {
		return nil, err
	}
	// Commit if cache didn't advance internally
	if cache.SeqLen() == before {
		cache.Commit(len(ids))
	}
	return logits, nil
}

func (s *Server) generate(ctx context.Context, requestID string, prompt string, maxTokens int, temperature float64, topP float64, topK int) (string, error) {
	scratchSet, releaseScratch, err := s.acquireScratchSet(ctx)
	if err != nil {
		return "", err
	}
	if scratchSet != nil {
		defer releaseScratch()
		ctx = compute.WithScratchSet(ctx, scratchSet)
		if len(s.scratchGPUs) > 0 {
			if sc := scratchSet.Scratch(s.scratchGPUs[0]); sc != nil {
				ctx = compute.WithScratch(ctx, sc)
			}
		}
	}
	ids := s.tok.Encode(prompt)
	if len(ids) == 0 {
		return "", fmt.Errorf("empty prompt after tokenization")
	}
	// Create cache - use model's CacheFactory if available, otherwise PagedKVCache
	var cache kvcache.KVCacheInterface
	var pagedCache *kvcache.PagedKVCache
	var modelCache model.KVCache
	var cachedTokens int
	if s.cacheFactory != nil {
		// Model uses custom cache (e.g., KimiLinear with recurrent state)
		mc, err := s.cacheFactory.CreateCache()
		if err != nil {
			return "", fmt.Errorf("model cache creation failed: %w", err)
		}
		modelCache = mc
		// modelCache implements model.KVCache but not kvcache.KVCacheInterface
		// We'll use it directly with the model's Forward function
	} else {
		// Standard attention model - use PagedKVCache
		if s.blockPool == nil {
			return "", fmt.Errorf("BlockPool not initialized")
		}
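		// Each request gets its own PagedKVCache view over the shared BlockPool.
		// AllocateForTokens also consults the pool's prefix cache, so a repeated
		// prompt prefix can skip prefill for the tokens that are already resident.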
		modelCfg := s.eng.Model().Config()
		pagedCache = kvcache.NewPagedKVCache(s.blockPool, kvcache.PagedCacheConfig{
			NumLayers:  modelCfg.NumLayers,
			NumKVHeads: modelCfg.NumKVHeads,
			HeadDim:    modelCfg.HeadDim,
			BlockSize:  s.blockSize,
			MaxSeqLen:  s.maxSeqLen,
			Device:     s.blockPoolDevice,
			GPU:        s.blockPoolGPU,
		}, requestID)
		cachedTokens, err = pagedCache.AllocateForTokens(ids)
		if err != nil {
			pagedCache.Free()
			return "", fmt.Errorf("PagedKVCache alloc failed: %w", err)
		}
		defer pagedCache.Free()
		cache = pagedCache
		if cachedTokens > 0 {
			log.Printf("prefix_cache_hit request=%s cached_tokens=%d prompt_tokens=%d", requestID, cachedTokens, len(ids))
		}
	}
	sampler := sample.New(sample.Config{
		Temperature:       float32(temperature),
		TopK:              topK,
		TopP:              float32(topP),
		RepetitionPenalty: 1.1,
		Seed:              -1,
	})
	// Prefill - skip cached tokens if we have a prefix cache hit
	prefillIDs := ids
	if cachedTokens > 0 && cachedTokens < len(ids) {
		prefillIDs = ids[cachedTokens:]
	}
	var logits tensor.Tensor
	if modelCache != nil {
		// Use model's cache (recurrent state) - no chunked prefill, process all at once
		logits, err = s.prefillWithModelCache(ctx, modelCache, prefillIDs)
	} else {
		logits, err = s.prefill(ctx, cache, prefillIDs)
	}
	if err != nil {
		return "", err
	}
	var nextToken int
	chunk := s.prefillChunkSize
	if chunk <= 0 {
		chunk = 512
	}
	lastPartLen := len(prefillIDs) % chunk
	if lastPartLen == 0 {
		lastPartLen = min(chunk, len(prefillIDs))
	}
	rowIdx := lastPartLen - 1
	if modelCache != nil {
		// For recurrent models, we processed all tokens at once
		rowIdx = len(prefillIDs) - 1
	}
	nextToken, err = sampleNextToken(logits, rowIdx, sampler, ids)
	if err != nil {
		return "", err
	}
	ids = append(ids, nextToken)
	if pagedCache != nil {
		pagedCache.AppendToken(nextToken)
	}
	var sb strings.Builder
	sb.WriteString(s.tok.Decode([]int{nextToken}))
	eosID := s.tok.EosID()
	useBatcher := s.batcher != nil && pagedCache != nil
	if useBatcher {
		if _, ok := s.eng.Model().(model.BatchForwarder); !ok {
			useBatcher = false
		}
	}
	if useBatcher {
		seq := &engine.DecodeSequence{
			RequestID:      requestID,
			Ctx:            ctx,
			Cache:          cache,
			History:        ids,
			NextInputToken: nextToken,
			Remaining:      maxTokens - 1,
			EosID:          eosID,
			Sampler:        sampler,
		}
		events, err := s.batcher.RegisterDecode(seq)
		if err != nil {
			return "", err
		}
		for ev := range events {
			if ev.Err != nil {
				return "", ev.Err
			}
			if ev.Done {
				break
			}
			ids = append(ids, ev.Token)
			sb.WriteString(s.tok.Decode([]int{ev.Token}))
		}
		return sb.String(), nil
	}
	// Decode loop - handle both PagedKVCache and model's cache
	for i := 1; i < maxTokens; i++ {
		if nextToken == eosID {
			break
		}
		select {
		case <-ctx.Done():
			return "", ctx.Err()
		default:
		}
		input := createInputTensor([]int{nextToken})
		positions := createPositionTensor(len(ids)-1, 1)
		if modelCache != nil {
			// Use model's cache (recurrent state)
			before := modelCache.SeqLen()
			logits, err = s.eng.Forward(ctx, input, positions, modelCache)
			if err != nil {
				return "", err
			}
			if modelCache.SeqLen() == before {
				modelCache.Commit(1)
			}
		} else {
			before := cache.SeqLen()
			logits, err = s.eng.Forward(ctx, input, positions, cache)
			if err != nil {
				return "", err
			}
			if cache.SeqLen() == before {
				cache.Commit(1)
			}
		}
		recent := ids
		if len(recent) > 64 {
			recent = recent[len(recent)-64:]
		}
		nextToken, err = sampleNextToken(logits, 0, sampler, recent)
		if err != nil {
			return "", err
		}
		ids = append(ids, nextToken)
		if pagedCache != nil {
			pagedCache.AppendToken(nextToken)
		}
		sb.WriteString(s.tok.Decode([]int{nextToken}))
	}
	return sb.String(), nil
}
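// generateStream mirrors generate but delivers output incrementally: every
// decoded piece is passed to onPiece as soon as it is sampled. Pieces are raw
// tokenizer output and may end mid-rune; callers buffer them (see
// utf8StreamBuffer) before emitting UTF-8 to the client.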
func (s *Server) generateStream(ctx context.Context, requestID string, prompt string, maxTokens int, temperature float64, topP float64, topK int, onPiece func(string) error) (string, error) {
	scratchSet, releaseScratch, err := s.acquireScratchSet(ctx)
	if err != nil {
		return "", err
	}
	if scratchSet != nil {
		defer releaseScratch()
		ctx = compute.WithScratchSet(ctx, scratchSet)
		if len(s.scratchGPUs) > 0 {
			if sc := scratchSet.Scratch(s.scratchGPUs[0]); sc != nil {
				ctx = compute.WithScratch(ctx, sc)
			}
		}
	}
	ids := s.tok.Encode(prompt)
	if len(ids) == 0 {
		return "", fmt.Errorf("empty prompt after tokenization")
	}
	// Create cache - use model's CacheFactory if available, otherwise PagedKVCache
	var cache kvcache.KVCacheInterface
	var pagedCache *kvcache.PagedKVCache
	var modelCache model.KVCache
	var cachedTokens int
	if s.cacheFactory != nil {
		// Model uses custom cache (e.g., KimiLinear with recurrent state)
		mc, err := s.cacheFactory.CreateCache()
		if err != nil {
			return "", fmt.Errorf("model cache creation failed: %w", err)
		}
		modelCache = mc
	} else {
		// Standard attention model - use PagedKVCache
		if s.blockPool == nil {
			return "", fmt.Errorf("BlockPool not initialized")
		}
		modelCfg := s.eng.Model().Config()
		pagedCache = kvcache.NewPagedKVCache(s.blockPool, kvcache.PagedCacheConfig{
			NumLayers:  modelCfg.NumLayers,
			NumKVHeads: modelCfg.NumKVHeads,
			HeadDim:    modelCfg.HeadDim,
			BlockSize:  s.blockSize,
			MaxSeqLen:  s.maxSeqLen,
			Device:     s.blockPoolDevice,
			GPU:        s.blockPoolGPU,
		}, requestID)
		cachedTokens, err = pagedCache.AllocateForTokens(ids)
		if err != nil {
			pagedCache.Free()
			return "", fmt.Errorf("PagedKVCache alloc failed: %w", err)
		}
		defer pagedCache.Free()
		cache = pagedCache
		if cachedTokens > 0 {
			log.Printf("prefix_cache_hit request=%s cached_tokens=%d prompt_tokens=%d", requestID, cachedTokens, len(ids))
		}
	}
	sampler := sample.New(sample.Config{
		Temperature:       float32(temperature),
		TopK:              topK,
		TopP:              float32(topP),
		RepetitionPenalty: 1.1,
		Seed:              -1,
	})
	// Prefill - skip cached tokens if we have a prefix cache hit
	prefillIDs := ids
	if cachedTokens > 0 && cachedTokens < len(ids) {
		prefillIDs = ids[cachedTokens:]
	}
	var logits tensor.Tensor
	if modelCache != nil {
		// Use model's cache (recurrent state) - no chunked prefill, process all at once
		logits, err = s.prefillWithModelCache(ctx, modelCache, prefillIDs)
	} else {
		logits, err = s.prefill(ctx, cache, prefillIDs)
	}
	if err != nil {
		return "", err
	}
	var nextToken int
	chunk := s.prefillChunkSize
	if chunk <= 0 {
		chunk = 512
	}
	lastPartLen := len(prefillIDs) % chunk
	if lastPartLen == 0 {
		lastPartLen = min(chunk, len(prefillIDs))
	}
	rowIdx := lastPartLen - 1
	if modelCache != nil {
		// For recurrent models, we processed all tokens at once
		rowIdx = len(prefillIDs) - 1
	}
	nextToken, err = sampleNextToken(logits, rowIdx, sampler, ids)
	if err != nil {
		return "", err
	}
	ids = append(ids, nextToken)
	if pagedCache != nil {
		pagedCache.AppendToken(nextToken)
	}
	var sb strings.Builder
	first := s.tok.Decode([]int{nextToken})
	sb.WriteString(first)
	if err := onPiece(first); err != nil {
		return sb.String(), err
	}
	eosID := s.tok.EosID()
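	// If the shared Batcher is running and the model supports batched decoding
	// (model.BatchForwarder), hand the decode loop to it so concurrent requests
	// share forward passes; otherwise fall back to the single-sequence loop below.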
	useBatcher := s.batcher != nil && pagedCache != nil
	if useBatcher {
		if _, ok := s.eng.Model().(model.BatchForwarder); !ok {
			useBatcher = false
		}
	}
	if useBatcher {
		seq := &engine.DecodeSequence{
			RequestID:      requestID,
			Ctx:            ctx,
			Cache:          cache,
			History:        ids,
			NextInputToken: nextToken,
			Remaining:      maxTokens - 1,
			EosID:          eosID,
			Sampler:        sampler,
		}
		events, err := s.batcher.RegisterDecode(seq)
		if err != nil {
			return sb.String(), err
		}
		for ev := range events {
			if ev.Err != nil {
				return sb.String(), ev.Err
			}
			if ev.Done {
				break
			}
			ids = append(ids, ev.Token)
			piece := s.tok.Decode([]int{ev.Token})
			sb.WriteString(piece)
			if err := onPiece(piece); err != nil {
				return sb.String(), err
			}
		}
		return sb.String(), nil
	}
	// Decode loop - handle both PagedKVCache and model's cache
	for i := 1; i < maxTokens; i++ {
		if nextToken == eosID {
			break
		}
		select {
		case <-ctx.Done():
			return sb.String(), ctx.Err()
		default:
		}
		input := createInputTensor([]int{nextToken})
		positions := createPositionTensor(len(ids)-1, 1)
		if modelCache != nil {
			// Use model's cache (recurrent state)
			before := modelCache.SeqLen()
			logits, err = s.eng.Forward(ctx, input, positions, modelCache)
			if err != nil {
				return sb.String(), err
			}
			if modelCache.SeqLen() == before {
				modelCache.Commit(1)
			}
		} else {
			before := cache.SeqLen()
			logits, err = s.eng.Forward(ctx, input, positions, cache)
			if err != nil {
				return sb.String(), err
			}
			if cache.SeqLen() == before {
				cache.Commit(1)
			}
		}
		recent := ids
		if len(recent) > 64 {
			recent = recent[len(recent)-64:]
		}
		nextToken, err = sampleNextToken(logits, 0, sampler, recent)
		if err != nil {
			return sb.String(), err
		}
		ids = append(ids, nextToken)
		if pagedCache != nil {
			pagedCache.AppendToken(nextToken)
		}
		piece := s.tok.Decode([]int{nextToken})
		sb.WriteString(piece)
		if err := onPiece(piece); err != nil {
			return sb.String(), err
		}
	}
	return sb.String(), nil
}

func createInputTensor(ids []int) tensor.Tensor {
	t := cpu.NewTensor(tensor.Shape{len(ids)}, nil)
	data := t.DataFloat32()
	for i, id := range ids {
		data[i] = float32(id)
	}
	return t
}

func createPositionTensor(start, count int) tensor.Tensor {
	t := cpu.NewTensor(tensor.Shape{count}, nil)
	data := t.DataFloat32()
	for i := 0; i < count; i++ {
		data[i] = float32(start + i)
	}
	return t
}

func getLogitsRowCPU(logits tensor.Tensor, row int) []float32 {
	if _, ok := logits.(*cpu.Tensor); !ok {
		return nil
	}
	data := logits.Data().(unsafe.Pointer)
	shape := logits.Shape()
	vocabSize := shape[1]
	slice := unsafe.Slice((*float32)(data), shape.NumElements())
	return slice[row*vocabSize : (row+1)*vocabSize]
}

func sampleNextToken(logits tensor.Tensor, row int, sampler *sample.Sampler, recentTokens []int) (int, error) {
	if logits == nil {
		return 0, fmt.Errorf("nil logits")
	}
	if sampler == nil {
		return 0, fmt.Errorf("nil sampler")
	}
	recent := recentTokens
	if len(recent) > 64 {
		recent = recent[len(recent)-64:]
	}
	if logitsCPU := getLogitsRowCPU(logits, row); logitsCPU != nil {
		return sampler.Sample(logitsCPU, recent), nil
	}
	gpuLogits, ok := logits.(*cuda.Tensor)
	if !ok {
		return 0, fmt.Errorf("unexpected logits type %T", logits)
	}
	shape := gpuLogits.Shape()
	if len(shape) != 2 {
		return 0, fmt.Errorf("expected 2D logits, got shape %v", shape)
	}
	vocabSize := shape[1]
	view, err := gpuLogits.ViewAt(tensor.Shape{vocabSize}, uintptr(row*vocabSize*4))
	if err != nil {
		return 0, err
	}
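	// GPU sampling path: TopKLogitsF32 applies the repetition penalty and keeps
	// per-block top-k candidates on device, so only blocks*k (id, score) pairs
	// are copied back; the host merges them and hands the final top-k to the
	// sampler. Greedy decoding (temperature 0) degenerates to k=1.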
	cfg := sampler.Config()
	k := cfg.TopK
	if cfg.Temperature == 0 {
		k = 1
	}
	// CUDA TopK kernel supports k<=64; fall back to full D2H if disabled or too large.
	if k <= 0 || k > 64 {
		host := make([]float32, vocabSize)
		if err := view.CopyToHost(host); err != nil {
			return 0, err
		}
		return sampler.Sample(host, recent), nil
	}
	repPenalty := cfg.RepetitionPenalty
	if repPenalty <= 0 {
		repPenalty = 1.0
	}
	repIDs := make([]int32, len(recent))
	for i, t := range recent {
		repIDs[i] = int32(t)
	}
	allIDs, allScores, blocks, err := cuda.TopKLogitsF32(view.Data().(unsafe.Pointer), vocabSize, repIDs, repPenalty, k, gpuLogits.GPU())
	if err != nil {
		return 0, err
	}
	cands := make([]struct {
		id    int32
		score float32
	}, 0, blocks*k)
	for i := 0; i < blocks*k; i++ {
		if allIDs[i] < 0 {
			continue
		}
		cands = append(cands, struct {
			id    int32
			score float32
		}{id: allIDs[i], score: allScores[i]})
	}
	if len(cands) == 0 {
		// Defensive fallback.
		host := make([]float32, vocabSize)
		if err := view.CopyToHost(host); err != nil {
			return 0, err
		}
		return sampler.Sample(host, recent), nil
	}
	sort.Slice(cands, func(i, j int) bool { return cands[i].score > cands[j].score })
	if len(cands) > k {
		cands = cands[:k]
	}
	finalIDs := make([]int32, len(cands))
	finalScores := make([]float32, len(cands))
	for i := range cands {
		finalIDs[i] = cands[i].id
		finalScores[i] = cands[i].score
	}
	return sampler.SampleFromTopK(finalIDs, finalScores), nil
}