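// An OpenAI-compatible chat completions server on top of a locally loaded
// .mak model. Non-streaming only. An illustrative request against the
// default flags (the model id "local" matches what /v1/models advertises):
//
//	curl -s http://localhost:8080/v1/chat/completions \
//	  -H 'Content-Type: application/json' \
//	  -d '{"model":"local","messages":[{"role":"user","content":"Hello"}]}'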
package main

import (
    "context"
    "encoding/json"
    "flag"
    "fmt"
    "log"
    "math/rand"
    "net/http"
    "strings"
    "time"
    "unsafe"

    "makarna/pkg/backend/cpu"
    "makarna/pkg/backend/cuda"
    "makarna/pkg/backend/device"
    "makarna/pkg/chat"
    "makarna/pkg/engine"
    "makarna/pkg/kvcache"
    "makarna/pkg/sample"
    "makarna/pkg/tensor"
    "makarna/pkg/tokenizer"
)
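
// The request and response types below mirror the subset of the OpenAI
// chat completions wire format that this server accepts and returns.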
type chatCompletionRequest struct {
    Model            string                  `json:"model"`
    Messages         []chatCompletionMessage `json:"messages"`
    Tools            []any                   `json:"tools,omitempty"`
    Stream           bool                    `json:"stream,omitempty"`
    MaxTokens        int                     `json:"max_tokens,omitempty"`
    Temperature      float64                 `json:"temperature,omitempty"`
    TopP             float64                 `json:"top_p,omitempty"`
    TopK             int                     `json:"top_k,omitempty"`
    PresencePenalty  float64                 `json:"presence_penalty,omitempty"`
    FrequencyPenalty float64                 `json:"frequency_penalty,omitempty"`
}

type chatCompletionMessage struct {
    Role       string `json:"role"`
    Content    string `json:"content"`
    Name       string `json:"name,omitempty"`
    ToolCallID string `json:"tool_call_id,omitempty"`
}

type chatCompletionResponse struct {
    ID      string                 `json:"id"`
    Object  string                 `json:"object"`
    Created int64                  `json:"created"`
    Model   string                 `json:"model"`
    Usage   chatCompletionUsage    `json:"usage"`
    Choices []chatCompletionChoice `json:"choices"`
}

type chatCompletionUsage struct {
    PromptTokens     int `json:"prompt_tokens"`
    CompletionTokens int `json:"completion_tokens"`
    TotalTokens      int `json:"total_tokens"`
}

type chatCompletionChoice struct {
    Index        int                  `json:"index"`
    Message      chatCompletionOutMsg `json:"message"`
    FinishReason string               `json:"finish_reason"`
}

type chatCompletionOutMsg struct {
    Role      string           `json:"role"`
    Content   string           `json:"content"`
    ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
}

type openAIToolCall struct {
    ID       string             `json:"id"`
    Type     string             `json:"type"`
    Function openAIFunctionCall `json:"function"`
}

type openAIFunctionCall struct {
    Name      string `json:"name"`
    Arguments string `json:"arguments"`
}
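
// server bundles the loaded engine, tokenizer, and per-request KV cache
// sizing shared by all HTTP handlers.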
type server struct {
    eng       *engine.Engine
    tok       *tokenizer.Tokenizer
    arch      string
    maxSeqLen int
    blockSize int
}

func main() {
    listen := flag.String("listen", ":8080", "listen address")
    modelPath := flag.String("model", "model.mak", "Path to .mak model file")
    maxSeq := flag.Int("max-seq-len", 8192, "Maximum sequence length to reserve in KV cache")
    blockSize := flag.Int("block-size", 32, "KV cache block size")
    nGPULayers := flag.Int("n-gpu-layers", -1, "Number of layers to offload to GPU (-1=auto, 0=CPU only)")
    gpuBudget := flag.Float64("gpu-budget", 0.9, "Fraction of GPU memory to use (0.0-1.0)")
    flag.Parse()

    cfg := engine.Config{GPULayers: *nGPULayers, GPUBudget: *gpuBudget}
    eng, err := engine.Load(*modelPath, cfg)
    if err != nil {
        log.Fatalf("load model: %v", err)
    }
    defer eng.Close()

    md := eng.Model().Config()
    var tok *tokenizer.Tokenizer
    tokData, err := eng.Loader().GetTokenizerData()
    if err == nil && len(tokData) > 0 {
        tok, err = tokenizer.LoadFromBytes(tokData)
        if err != nil {
            log.Printf("warning: load embedded tokenizer: %v", err)
        }
    }
    if tok == nil {
        log.Fatalf("tokenizer not found in model file")
    }

    s := &server{eng: eng, tok: tok, arch: md.Architecture, maxSeqLen: *maxSeq, blockSize: *blockSize}
    h := http.NewServeMux()
    h.HandleFunc("/v1/chat/completions", s.handleChatCompletions)
    h.HandleFunc("/v1/models", s.handleModels)
    log.Printf("listening on %s (arch=%s, cuda=%v)", *listen, s.arch, device.CUDAAvailable())
    log.Fatal(http.ListenAndServe(*listen, h))
}
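
// handleModels advertises a single static model entry so that OpenAI
// clients can discover the server.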
func (s *server) handleModels(w http.ResponseWriter, r *http.Request) {
    if r.Method != http.MethodGet {
        http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
        return
    }
    resp := map[string]any{
        "object": "list",
        "data":   []any{map[string]any{"id": "local", "object": "model"}},
    }
    writeJSON(w, resp)
}
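
// handleChatCompletions implements the non-streaming chat completions flow:
// decode the request, render the architecture-specific chat template,
// generate, strip thinking traces, extract tool calls, and reply in the
// OpenAI response shape.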
func (s *server) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
    if r.Method != http.MethodPost {
        http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
        return
    }
    var req chatCompletionRequest
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        http.Error(w, "bad json", http.StatusBadRequest)
        return
    }
    if req.Stream {
        http.Error(w, "stream not implemented", http.StatusNotImplemented)
        return
    }
    if len(req.Messages) == 0 {
        http.Error(w, "messages required", http.StatusBadRequest)
        return
    }
    msgs := make([]chat.Message, 0, len(req.Messages))
    for _, m := range req.Messages {
        role := strings.ToLower(m.Role)
        msgs = append(msgs, chat.Message{Role: role, Content: m.Content})
    }
    prompt, err := chat.RenderForArchitecture(s.arch, msgs, chat.Options{
        AddGenerationPrompt: true,
        EnableThinking:      true,
        Tools:               req.Tools,
    })
    if err != nil {
        http.Error(w, fmt.Sprintf("render prompt: %v", err), http.StatusInternalServerError)
        return
    }
    promptTokens := len(s.tok.Encode(prompt))

    // Zero values mean "not provided"; fall back to server defaults. A side
    // effect is that an explicit temperature of 0 cannot be requested.
    maxTokens := req.MaxTokens
    if maxTokens <= 0 {
        maxTokens = 128
    }
    temp := req.Temperature
    if temp == 0 {
        temp = 0.7
    }
    topP := req.TopP
    if topP == 0 {
        topP = 0.9
    }
    topK := req.TopK
    if topK == 0 {
        topK = 40
    }

    outText, err := s.generate(r.Context(), prompt, maxTokens, temp, topP, topK)
    if err != nil {
        http.Error(w, fmt.Sprintf("generate: %v", err), http.StatusInternalServerError)
        return
    }
    completionTokens := len(s.tok.Encode(outText))

    _, content := chat.StripThinking(outText)
    content, calls, err := chat.ExtractToolCalls(content)
    if err != nil {
        http.Error(w, fmt.Sprintf("parse tool_calls: %v", err), http.StatusInternalServerError)
        return
    }
    content = strings.TrimSpace(content)
    var outCalls []openAIToolCall
    for i, c := range calls {
        outCalls = append(outCalls, openAIToolCall{
            ID:   fmt.Sprintf("call_%d_%d", time.Now().UnixNano(), i),
            Type: "function",
            Function: openAIFunctionCall{
                Name:      c.Name,
                Arguments: string(bytesOrEmptyObject(c.Arguments)),
            },
        })
    }
    finish := "stop"
    if len(outCalls) > 0 {
        finish = "tool_calls"
    }
    resp := chatCompletionResponse{
        ID:      fmt.Sprintf("chatcmpl_%d", rand.Int63()),
        Object:  "chat.completion",
        Created: time.Now().Unix(),
        Model:   req.Model,
        Usage: chatCompletionUsage{
            PromptTokens:     promptTokens,
            CompletionTokens: completionTokens,
            TotalTokens:      promptTokens + completionTokens,
        },
        Choices: []chatCompletionChoice{{
            Index: 0,
            Message: chatCompletionOutMsg{
                Role:      "assistant",
                Content:   content,
                ToolCalls: outCalls,
            },
            FinishReason: finish,
        }},
    }
    writeJSON(w, resp)
}
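
// bytesOrEmptyObject substitutes the empty JSON object for absent tool-call
// arguments, which downstream OpenAI clients generally expect over an
// empty string.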
func bytesOrEmptyObject(b []byte) []byte {
    if len(b) == 0 {
        return []byte("{}")
    }
    return b
}
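
// writeJSON encodes v as JSON with HTML escaping disabled, so characters
// such as < and > in model output are not rewritten to \u003c escapes.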
func writeJSON(w http.ResponseWriter, v any) {
    w.Header().Set("Content-Type", "application/json")
    enc := json.NewEncoder(w)
    enc.SetEscapeHTML(false)
    _ = enc.Encode(v)
}
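
// generate runs a full prefill-then-decode pass for one request: the whole
// prompt is forwarded once to fill a freshly allocated paged KV cache, then
// tokens are decoded one at a time until EOS or maxTokens. The cache lives
// only for the duration of the call.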
func (s *server) generate(ctx context.Context, prompt string, maxTokens int, temperature float64, topP float64, topK int) (string, error) {
    ids := s.tok.Encode(prompt)
    if len(ids) == 0 {
        return "", fmt.Errorf("empty prompt after tokenization")
    }
    modelCfg := s.eng.Model().Config()
    placements := make([]tensor.DevicePlacement, modelCfg.NumLayers)
    if s.eng.Dispatcher() != nil {
        for i := 0; i < modelCfg.NumLayers; i++ {
            placements[i] = s.eng.Dispatcher().LayerPlacement(i)
        }
    }
    // Enable mixed per-layer KV cache when any layer is on GPU.
    kvDevice := tensor.CPU
    if device.CUDAAvailable() {
        for i := 0; i < modelCfg.NumLayers && i < len(placements); i++ {
            if placements[i].Normalize().Type == tensor.CUDA {
                kvDevice = tensor.CUDA
                break
            }
        }
    }
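
    // NumBlocks rounds maxSeqLen up to a whole number of cache blocks,
    // i.e. ceil(maxSeqLen / blockSize). Per-layer placements are forwarded
    // only when the cache is on CUDA, presumably letting CPU-resident layers
    // keep their KV blocks in host memory.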
    pool, err := kvcache.NewBlockPool(kvcache.BlockPoolConfig{
        NumLayers:  modelCfg.NumLayers,
        NumKVHeads: modelCfg.NumKVHeads,
        HeadDim:    modelCfg.HeadDim,
        BlockSize:  s.blockSize,
        NumBlocks:  (s.maxSeqLen + s.blockSize - 1) / s.blockSize,
        Device:     kvDevice,
        GPU:        0,
        LayerPlacements: func() []tensor.DevicePlacement {
            if kvDevice != tensor.CUDA || len(placements) != modelCfg.NumLayers {
                return nil
            }
            out := make([]tensor.DevicePlacement, modelCfg.NumLayers)
            for i := 0; i < modelCfg.NumLayers; i++ {
                out[i] = placements[i].Normalize()
            }
            return out
        }(),
        Preallocate: kvDevice == tensor.CUDA,
    })
    if err != nil {
        return "", err
    }
    cache := kvcache.NewPagedKVCache(pool, kvcache.PagedCacheConfig{
        NumLayers:  modelCfg.NumLayers,
        NumKVHeads: modelCfg.NumKVHeads,
        HeadDim:    modelCfg.HeadDim,
        BlockSize:  s.blockSize,
        MaxSeqLen:  s.maxSeqLen,
        Device:     kvDevice,
        GPU:        0,
    }, "cmd-openai")
    if _, err := cache.AllocateForTokens(ids); err != nil {
        cache.Free()
        return "", err
    }
    defer cache.Free()
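
    // The repetition penalty is fixed at 1.1; Seed -1 is assumed to select a
    // non-deterministic seed in the sample package.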
    sampler := sample.New(sample.Config{
        Temperature:       float32(temperature),
        TopK:              topK,
        TopP:              float32(topP),
        RepetitionPenalty: 1.1,
        Seed:              -1,
    })

    // Prefill: forward the whole prompt in one pass.
    input := createInputTensor(ids)
    positions := createPositionTensor(0, len(ids))
    logits, err := s.eng.Forward(ctx, input, positions, cache)
    if err != nil {
        return "", err
    }

    // Sample the first token from the last row of the prefill logits. When
    // the logits live on the GPU, copy just that row back to the host.
    var nextToken int
    if logitsCPU := getLogitsRowCPU(logits, len(ids)-1); logitsCPU != nil {
        nextToken = sampler.Sample(logitsCPU, ids)
    } else {
        gpuLogits := logits.(*cuda.Tensor)
        vocabSize := gpuLogits.Shape()[1]
        row := len(ids) - 1
        // Offset is in bytes: row * vocabSize float32s, 4 bytes each.
        view, err := gpuLogits.ViewAt(tensor.Shape{vocabSize}, uintptr(row*vocabSize*4))
        if err != nil {
            return "", err
        }
        host := make([]float32, vocabSize)
        if err := view.CopyToHost(host); err != nil {
            return "", err
        }
        nextToken = sampler.Sample(host, ids)
    }
    ids = append(ids, nextToken)
    var sb strings.Builder
    sb.WriteString(s.tok.Decode([]int{nextToken}))
    eosID := s.tok.EosID()
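
    // Decode loop: feed the single newly sampled token back in, with its
    // absolute position, and reuse the paged KV cache built during prefill.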
    for i := 1; i < maxTokens; i++ {
        if nextToken == eosID {
            break
        }
        select {
        case <-ctx.Done():
            return "", ctx.Err()
        default:
        }
        input = createInputTensor([]int{nextToken})
        currentPos := len(ids) - 1
        positions = createPositionTensor(currentPos, 1)
        logits, err = s.eng.Forward(ctx, input, positions, cache)
        if err != nil {
            return "", err
        }
        // Restrict the repetition penalty window to the 64 most recent tokens.
        recent := ids
        if len(recent) > 64 {
            recent = recent[len(recent)-64:]
        }
        if logitsCPU := getLogitsRowCPU(logits, 0); logitsCPU != nil {
            nextToken = sampler.Sample(logitsCPU, recent)
        } else {
            gpuLogits := logits.(*cuda.Tensor)
            vocabSize := gpuLogits.Shape()[1]
            view, err := gpuLogits.ViewAt(tensor.Shape{vocabSize}, 0)
            if err != nil {
                return "", err
            }
            host := make([]float32, vocabSize)
            if err := view.CopyToHost(host); err != nil {
                return "", err
            }
            nextToken = sampler.Sample(host, recent)
        }
        ids = append(ids, nextToken)
        sb.WriteString(s.tok.Decode([]int{nextToken}))
    }
    return sb.String(), nil
}
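
// createInputTensor and createPositionTensor stage token IDs and positions
// as float32 CPU tensors, since the cpu backend exposes float32 storage
// (DataFloat32) as its element type.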
func createInputTensor(ids []int) tensor.Tensor {
    t := cpu.NewTensor(tensor.Shape{len(ids)}, nil)
    data := t.DataFloat32()
    for i, id := range ids {
        data[i] = float32(id)
    }
    return t
}

func createPositionTensor(start, count int) tensor.Tensor {
    t := cpu.NewTensor(tensor.Shape{count}, nil)
    data := t.DataFloat32()
    for i := 0; i < count; i++ {
        data[i] = float32(start + i)
    }
    return t
}
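
// getLogitsRowCPU returns one row of a CPU logits tensor as a []float32
// without copying, reinterpreting the raw buffer via unsafe. It returns nil
// for non-CPU tensors so callers fall through to the GPU copy path.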
func getLogitsRowCPU(logits tensor.Tensor, row int) []float32 {
    if _, ok := logits.(*cpu.Tensor); !ok {
        return nil
    }
    data := logits.Data().(unsafe.Pointer)
    shape := logits.Shape()
    vocabSize := shape[1]
    slice := unsafe.Slice((*float32)(data), shape.NumElements())
    return slice[row*vocabSize : (row+1)*vocabSize]
}