// main.go — makarna model runner: CLI generation and OpenAI-compatible server entry point.
  1. package main
  2. import (
  3. "context"
  4. "encoding/json"
  5. "flag"
  6. "fmt"
  7. "log"
  8. "net"
  9. "path/filepath"
  10. "sort"
  11. "strconv"
  12. "strings"
  13. "time"
  14. "unsafe"
  15. "makarna/pkg/backend/cpu"
  16. "makarna/pkg/backend/cuda"
  17. "makarna/pkg/backend/device"
  18. "makarna/pkg/chat"
  19. "makarna/pkg/compute"
  20. "makarna/pkg/engine"
  21. "makarna/pkg/kvcache"
  22. "makarna/pkg/loader"
  23. "makarna/pkg/openai"
  24. "makarna/pkg/profile"
  25. "makarna/pkg/sample"
  26. "makarna/pkg/model"
  27. "makarna/pkg/tensor"
  28. "makarna/pkg/tokenizer"
  29. kimi_linear "makarna/pkg/model/models/kimi_linear" // Register KimiLinear model
  30. _ "makarna/pkg/model/models/qwen3" // Register Qwen3 model
  31. )
// main is the CLI entry point. It parses flags, initializes optional
// profiling, loads the model (and tokenizer, embedded or sidecar), then
// either serves an OpenAI-compatible HTTP API or runs a single prompt:
// chunked prefill followed by autoregressive decode, streaming decoded
// tokens to stdout and printing a tok/s summary at the end.
func main() {
	modelPath := flag.String("model", "model.mak", "Path to .mak model file")
	prompt := flag.String("prompt", "Hello world", "Prompt to generate")
	steps := flag.Int("steps", 10, "Number of tokens to generate")
	useChat := flag.Bool("chat", false, "Use chat format for prompt")
	serverMode := flag.Bool("server", false, "Run OpenAI-compatible HTTP server")
	listen := flag.String("listen", "", "Server listen address (e.g. :8080, 0.0.0.0:8080). If set, implies --server")
	host := flag.String("host", "127.0.0.1", "Server host (used when --listen is empty)")
	port := flag.Int("port", 8080, "Server port (used when --listen is empty)")
	temperature := flag.Float64("temp", 0.7, "Sampling temperature (0 = greedy)")
	topK := flag.Int("top-k", 40, "Top-K sampling (0 = disabled)")
	topP := flag.Float64("top-p", 0.9, "Top-P nucleus sampling (1.0 = disabled)")
	repPenalty := flag.Float64("rep-penalty", 1.1, "Repetition penalty (1.0 = disabled)")
	threads := flag.Int("threads", -1, "Number of CPU threads to use (default: 90% of cores)")
	listTensors := flag.Bool("list-tensors", false, "List tensors in model and exit")
	useMmap := flag.Bool("mmap", false, "Use mmap for model weights (default: false)")
	// Device placement flags - llama.cpp style
	nGPULayers := flag.Int("n-gpu-layers", -1, "Number of layers to offload to GPU (-1=auto, 0=CPU only)")
	gpuBudget := flag.Float64("gpu-budget", 0.9, "Fraction of GPU memory to use (0.0-1.0)")
	gpuDevicesFlag := flag.String("gpu-devices", "", "Comma-separated GPU device ordinals to use (e.g. 0 or 0,1)")
	layerMap := flag.String("layer-map", "", "Advanced: layer placement map, e.g. 0-9:gpu0,10-19:gpu1,20-:cpu")
	cpuMoE := flag.Bool("cpu-moe", false, "Keep MoE expert weights on CPU (saves GPU memory for large MoE models)")
	blockSize := flag.Int("block-size", 16, "KV cache block size")
	maxSeq := flag.Int("max-seq-len", 2048, "Maximum sequence length to reserve in KV cache")
	kvCacheCPU := flag.Bool("kv-cache-cpu", false, "Force KV cache to CPU (default: GPU when available)")
	prefillChunkSize := flag.Int("prefill-chunk-size", 512, "Prompt prefill chunk size (llama.cpp eval batch size analogue)")
	maxConcurrent := flag.Int("max-concurrent", 1, "Server mode: max concurrent sequences to reserve KV/scratch for")
	// Profiling flags
	profileOn := flag.Bool("profile", false, "Enable profiling summary report (alias for -profile-log=report)")
	profileLog := flag.String("profile-log", "", "Enable profiling: 'true'=realtime screen, 'report'=summary only, or file path")
	flag.Parse()
	// Initialize profiling. -profile alone is shorthand for -profile-log=report.
	if *profileOn && *profileLog == "" {
		*profileLog = "report"
	}
	if *profileLog != "" {
		profile.Enable()
		switch strings.ToLower(*profileLog) {
		case "true", "1", "realtime":
			// Realtime output to stderr
			profile.SetRealtime(true)
			fmt.Println("Profiling enabled: realtime output to stderr")
		case "report", "summary":
			// Summary only at the end
			profile.SetRealtime(false)
			fmt.Println("Profiling enabled: summary report at end")
		default:
			// File path specified
			if err := profile.SetLogFile(*profileLog); err != nil {
				log.Fatalf("Failed to open profile log file: %v", err)
			}
			profile.SetRealtime(true)
			fmt.Printf("Profiling enabled: logging to %s\n", *profileLog)
		}
		// Emit the summary and release the log file when main returns.
		// NOTE: log.Fatalf paths below bypass deferred calls.
		defer func() {
			profile.Report()
			profile.Close()
		}()
	}
	cpu.SetMaxThreads(*threads)
	// Parse --gpu-devices into a list of ordinals; any malformed entry is fatal.
	var gpuDevices []int
	if strings.TrimSpace(*gpuDevicesFlag) != "" {
		for _, part := range strings.Split(*gpuDevicesFlag, ",") {
			part = strings.TrimSpace(part)
			if part == "" {
				continue
			}
			id, err := strconv.Atoi(part)
			if err != nil {
				log.Fatalf("invalid --gpu-devices entry %q: %v", part, err)
			}
			gpuDevices = append(gpuDevices, id)
		}
		if len(gpuDevices) == 0 {
			log.Fatalf("invalid --gpu-devices: no devices parsed")
		}
	}
	// Determine engine config
	cfg := engine.Config{
		GPULayers:  *nGPULayers,
		GPUBudget:  *gpuBudget,
		GPUDevices: gpuDevices,
		UseMmap:    *useMmap,
		CPUMoE:     *cpuMoE,
	}
	// Load Model
	fmt.Printf("Loading model from %s...\n", *modelPath)
	// --list-tensors: dump sorted tensor metadata (name, dtype, shape, size) and exit.
	if *listTensors {
		md, err := loader.LoadWithOptions(*modelPath, loader.LoadOptions{UseMmap: *useMmap})
		if err != nil {
			log.Fatalf("Failed to load model: %v", err)
		}
		defer md.Close()
		names := make([]string, 0, len(md.Metadata.Tensors))
		for name := range md.Metadata.Tensors {
			names = append(names, name)
		}
		sort.Strings(names)
		for _, name := range names {
			info := md.Metadata.Tensors[name]
			fmt.Printf("%s\t%s\t%v\t%d\n", name, info.DType.String(), info.Shape, info.Size)
		}
		return
	}
	eng, err := engine.Load(*modelPath, cfg)
	if err != nil {
		log.Fatalf("Failed to load model: %v", err)
	}
	defer eng.Close()
	// Show device info
	if device.CUDAAvailable() {
		fmt.Println("CUDA available: yes")
	} else {
		fmt.Println("CUDA available: no (CPU only)")
	}
	modelConfig := eng.Model().Config()
	// If layer-map is specified, use that for KV cache placement
	var placements []tensor.DevicePlacement
	if *layerMap != "" {
		placements = parseLayerMap(modelConfig.NumLayers, *layerMap)
	} else if eng.Dispatcher() != nil {
		// Use dispatcher's placements
		placements = make([]tensor.DevicePlacement, modelConfig.NumLayers)
		for i := 0; i < modelConfig.NumLayers; i++ {
			placements[i] = eng.Dispatcher().LayerPlacement(i)
		}
	}
	fmt.Println("Model loaded successfully!")
	// Load Tokenizer: prefer one embedded in the .mak file, fall back to a
	// tokenizer.json sitting next to the model. Failure is a warning only;
	// generation then uses placeholder token IDs (see below).
	var tok *tokenizer.Tokenizer
	tokData, err := eng.Loader().GetTokenizerData()
	if err == nil && len(tokData) > 0 {
		fmt.Println("Found embedded tokenizer in model file.")
		tok, err = tokenizer.LoadFromBytes(tokData)
		if err != nil {
			log.Printf("Warning: failed to load embedded tokenizer: %v", err)
		}
	}
	if tok == nil {
		modelDir := filepath.Dir(*modelPath)
		tokPath := filepath.Join(modelDir, "tokenizer.json")
		fmt.Printf("Loading tokenizer from %s...\n", tokPath)
		tok, err = tokenizer.LoadFromJSON(tokPath)
		if err != nil {
			log.Printf("Warning: failed to load tokenizer: %v", err)
		}
	}
	// Format prompt (optionally with chat template)
	finalPrompt := *prompt
	if *useChat {
		messages := []chat.Message{{Role: "user", Content: *prompt}}
		formatted, err := chat.RenderForArchitecture(modelConfig.Architecture, messages, chat.Options{
			AddGenerationPrompt: true,
			EnableThinking:      true,
		})
		if err != nil {
			log.Fatalf("format prompt failed: %v", err)
		}
		finalPrompt = formatted
		fmt.Printf("Formatted prompt:\n%s\n", finalPrompt)
	}
	// Server mode: start HTTP server and block.
	// We do this after model+tokenizer are loaded so all flags (GPU, KV cache sizes, etc.) apply.
	if *listen != "" {
		*serverMode = true
	}
	if *serverMode {
		addr := *listen
		if addr == "" {
			addr = net.JoinHostPort(*host, strconv.Itoa(*port))
		}
		// Ensure JSON is linked in this binary (avoid unused import when tags change)
		_ = json.Valid
		err := openai.Serve(eng, tok, modelConfig.Architecture, openai.Config{
			Listen:           addr,
			MaxSeqLen:        *maxSeq,
			BlockSize:        *blockSize,
			KVCacheCPU:       *kvCacheCPU,
			EnableThinking:   false,
			PrefillChunkSize: *prefillChunkSize,
			MaxConcurrent:    *maxConcurrent,
		})
		if err != nil {
			log.Fatalf("server failed: %v", err)
		}
		return
	}
	// Tokenize prompt; without a tokenizer fall back to fixed dummy IDs so
	// the pipeline can still be exercised.
	var ids []int
	if tok != nil {
		ids = tok.Encode(finalPrompt)
		fmt.Printf("Tokens: %v\n", ids)
	} else {
		ids = []int{1, 2, 3}
	}
	// Initialize KV Cache. KimiLinear uses its own CPU-resident cache built
	// from architecture-specific params; everything else gets a paged cache.
	var kv model.KVCache
	var pagedCache *kvcache.PagedKVCache
	if modelConfig.Architecture == "KimiLinearForCausalLM" {
		params := modelConfig.Params
		lacRaw := params["linear_attn_config"]
		// NOTE(review): the type assertions below panic if any param is
		// missing or not a float64 (JSON number) — assumed guaranteed by the
		// loader for this architecture; confirm.
		lac, _ := lacRaw.(map[string]any)
		kdaNumHeads := int(lac["num_heads"].(float64))
		kdaHeadDim := int(lac["head_dim"].(float64))
		kdaKernel := int(lac["short_conv_kernel_size"].(float64))
		mlaNumHeads := int(params["num_attention_heads"].(float64))
		qkNope := int(params["qk_nope_head_dim"].(float64))
		qkRope := int(params["qk_rope_head_dim"].(float64))
		vDim := int(params["v_head_dim"].(float64))
		kimiCache, err := kimi_linear.NewKimiCache(modelConfig.NumLayers, kdaNumHeads, kdaHeadDim, kdaKernel, mlaNumHeads, qkNope+qkRope, vDim)
		if err != nil {
			log.Fatalf("KimiCache alloc failed: %v", err)
		}
		kv = kimiCache
		fmt.Println("KV cache: KimiCache (CPU)")
	} else {
		// Default: enable GPU KV per-layer when ANY layer is on GPU (mixed offload supported),
		// unless --kv-cache-cpu is specified.
		kvDevice := tensor.CPU
		if !*kvCacheCPU && device.CUDAAvailable() {
			for i := 0; i < modelConfig.NumLayers && i < len(placements); i++ {
				if placements[i].Normalize().Type == tensor.CUDA {
					kvDevice = tensor.CUDA
					break
				}
			}
		}
		switch kvDevice {
		case tensor.CUDA:
			fmt.Println("KV cache: mixed (per-layer)")
		default:
			fmt.Println("KV cache: CPU")
		}
		pool, err := kvcache.NewBlockPool(kvcache.BlockPoolConfig{
			NumLayers:  modelConfig.NumLayers,
			NumKVHeads: modelConfig.NumKVHeads,
			HeadDim:    modelConfig.HeadDim,
			BlockSize:  *blockSize,
			// Enough blocks to cover max-seq-len, rounded up.
			NumBlocks: (*maxSeq + *blockSize - 1) / (*blockSize),
			Device:    kvDevice,
			GPU:       0,
			// Per-layer placements only apply when KV lives on GPU and we
			// have a placement for every layer; otherwise nil = uniform.
			LayerPlacements: func() []tensor.DevicePlacement {
				if kvDevice != tensor.CUDA || len(placements) != modelConfig.NumLayers {
					return nil
				}
				out := make([]tensor.DevicePlacement, modelConfig.NumLayers)
				for i := 0; i < modelConfig.NumLayers; i++ {
					out[i] = placements[i].Normalize()
				}
				return out
			}(),
			Preallocate: kvDevice == tensor.CUDA,
		})
		if err != nil {
			log.Fatalf("NewBlockPool failed: %v", err)
		}
		pagedCache = kvcache.NewPagedKVCache(pool, kvcache.PagedCacheConfig{
			NumLayers:  modelConfig.NumLayers,
			NumKVHeads: modelConfig.NumKVHeads,
			HeadDim:    modelConfig.HeadDim,
			BlockSize:  *blockSize,
			MaxSeqLen:  *maxSeq,
			Device:     kvDevice,
			GPU:        0,
		}, "run-model")
		if _, err := pagedCache.AllocateForTokens(ids); err != nil {
			pagedCache.Free()
			log.Fatalf("PagedKVCache alloc failed: %v", err)
		}
		defer pagedCache.Free()
		kv = pagedCache
	}
	// Preallocate scratch buffers once so prefill doesn't hit cudaMalloc churn.
	runCtx := context.Background()
	needWarmup := false
	if device.CUDAAvailable() && eng.Dispatcher() != nil && cuda.Available() {
		// Collect the distinct GPU ordinals actually hosting layers.
		gpuSeen := make(map[int]struct{})
		var gpus []int
		for i := 0; i < modelConfig.NumLayers; i++ {
			p := eng.Dispatcher().LayerPlacement(i).Normalize()
			if p.Type != tensor.CUDA || p.GPU < 0 {
				continue
			}
			if _, ok := gpuSeen[p.GPU]; ok {
				continue
			}
			gpuSeen[p.GPU] = struct{}{}
			gpus = append(gpus, p.GPU)
		}
		if len(gpus) > 0 {
			const minScratchBytes = 8 << 20
			var (
				ss          *compute.ScratchSet
				scratchErr  error
				scratchSize = compute.DefaultScratchBytes
			)
			// Halve the scratch size on allocation failure until it fits or
			// drops below the 8 MiB floor; scratch is optional, so give up
			// gracefully rather than aborting.
			for scratchSize >= minScratchBytes {
				var err error
				ss, err = compute.NewScratchSet(gpus, scratchSize)
				if err == nil {
					break
				}
				scratchErr = err
				ss = nil
				scratchSize /= 2
			}
			if ss != nil {
				defer ss.Free()
				runCtx = compute.WithScratchSet(runCtx, ss)
				runCtx = compute.WithScratch(runCtx, ss.Scratch(gpus[0]))
				needWarmup = true
				log.Printf("scratch: gpus=%v bytes=%d", gpus, scratchSize)
			} else if scratchErr != nil {
				log.Printf("scratch disabled (alloc failed): %v", scratchErr)
			}
		}
	}
	// One throwaway forward pass (nil KV cache) primes GPU-side buffers
	// before timing-sensitive work.
	if needWarmup {
		if _, err := eng.Forward(runCtx, createInputTensor([]int{0}), createPositionTensor(0, 1), nil); err != nil {
			log.Fatalf("warmup forward failed: %v", err)
		}
		compute.LogWeightCacheSummary()
	}
	// Initialize Sampler
	sampler := sample.New(sample.Config{
		Temperature:       float32(*temperature),
		TopK:              *topK,
		TopP:              float32(*topP),
		RepetitionPenalty: float32(*repPenalty),
		Seed:              -1,
	})
	// Prefill prompt in chunks (llama.cpp eval batch size analogue).
	chunk := *prefillChunkSize
	if chunk <= 0 {
		chunk = 512
	}
	var logits tensor.Tensor
	for start := 0; start < len(ids); start += chunk {
		end := start + chunk
		if end > len(ids) {
			end = len(ids)
		}
		part := ids[start:end]
		input := createInputTensor(part)
		positions := createPositionTensor(kv.SeqLen(), len(part))
		before := kv.SeqLen()
		profile.Start("Prefill/Forward")
		out, err := eng.Forward(runCtx, input, positions, kv)
		profile.End("Prefill/Forward")
		if err != nil {
			log.Fatalf("Prefill forward failed: %v", err)
		}
		logits = out
		// Only commit manually when Forward did not advance the cache's
		// sequence length itself (some cache implementations self-advance).
		if kv.SeqLen() == before {
			kv.Commit(len(part))
		}
	}
	if logits == nil {
		log.Fatalf("prefill produced nil logits")
	}
	// Sample first generated token. The logits tensor holds one row per
	// token of the LAST chunk, so recover that chunk's length to find the
	// final row index.
	lastPartLen := len(ids) % chunk
	if lastPartLen == 0 {
		if chunk < len(ids) {
			lastPartLen = chunk
		} else {
			lastPartLen = len(ids)
		}
	}
	rowIdx := lastPartLen - 1
	logitsSlice := getLogitsRowCPU(logits, rowIdx)
	var nextToken int
	if logitsSlice != nil {
		profile.Start("Prefill/Sample")
		nextToken = sampler.Sample(logitsSlice, ids)
		profile.End("Prefill/Sample")
	} else {
		// CUDA path: take top-k (or argmax) from GPU, copy only small candidate list.
		gpuLogits := logits.(*cuda.Tensor)
		vocabSize := gpuLogits.Shape()[1]
		// View the last row; offset is in bytes (float32 = 4 bytes).
		view, err := gpuLogits.ViewAt(tensor.Shape{vocabSize}, uintptr(rowIdx*vocabSize*4))
		if err != nil {
			log.Fatalf("logits view failed: %v", err)
		}
		k := *topK
		if *temperature == 0 {
			k = 1
		}
		if k <= 0 {
			// Semantics-preserving fallback: copy full logits row to CPU and use existing sampler.
			host := make([]float32, vocabSize)
			profile.Start("Prefill/LogitsD2H")
			if err := view.CopyToHost(host); err != nil {
				log.Fatalf("logits D2H failed: %v", err)
			}
			profile.End("Prefill/LogitsD2H")
			profile.Start("Prefill/Sample")
			nextToken = sampler.Sample(host, ids)
			profile.End("Prefill/Sample")
			goto sampledPrefill
		}
		// Repetition penalty considers at most the 64 most recent tokens.
		recent := ids
		if len(recent) > 64 {
			recent = recent[len(recent)-64:]
		}
		repIDs := make([]int32, len(recent))
		for i, t := range recent {
			repIDs[i] = int32(t)
		}
		profile.Start("Prefill/TopK")
		allIDs, allScores, blocks, err := cuda.TopKLogitsF32(view.Data().(unsafe.Pointer), vocabSize, repIDs, float32(*repPenalty), k, gpuLogits.GPU())
		profile.End("Prefill/TopK")
		if err != nil {
			log.Fatalf("cuda topk failed: %v", err)
		}
		// Merge per-block candidates on CPU to get global top-k
		cands := make([]struct {
			id    int32
			score float32
		}, 0, blocks*k)
		for i := 0; i < blocks*k; i++ {
			// Negative IDs mark empty candidate slots from the kernel.
			if allIDs[i] < 0 {
				continue
			}
			cands = append(cands, struct {
				id    int32
				score float32
			}{id: allIDs[i], score: allScores[i]})
		}
		sort.Slice(cands, func(i, j int) bool { return cands[i].score > cands[j].score })
		if len(cands) > k {
			cands = cands[:k]
		}
		finalIDs := make([]int32, len(cands))
		finalScores := make([]float32, len(cands))
		for i := range cands {
			finalIDs[i] = cands[i].id
			finalScores[i] = cands[i].score
		}
		profile.Start("Prefill/Sample")
		nextToken = sampler.SampleFromTopK(finalIDs, finalScores)
		profile.End("Prefill/Sample")
	}
sampledPrefill:
	if tok != nil {
		fmt.Print(tok.Decode([]int{nextToken}))
	}
	ids = append(ids, nextToken)
	if pagedCache != nil {
		pagedCache.AppendToken(nextToken)
	}
	// Autoregressive generation with KV Cache
	eosID := 151645 // <|im_end|> — Qwen-style fallback when no tokenizer is loaded
	if tok != nil {
		eosID = tok.EosID()
	}
	startGen := time.Now()
	genTokens := 0
	for i := 1; i < *steps; i++ {
		profile.TokenStart()
		// Check for EOS
		if nextToken == eosID {
			profile.TokenEnd()
			break
		}
		// Prepare single token input
		input := createInputTensor([]int{nextToken})
		currentPos := len(ids) - 1
		positions := createPositionTensor(currentPos, 1)
		profile.Start("Decode/Forward")
		logits, err = eng.Forward(runCtx, input, positions, kv)
		profile.End("Decode/Forward")
		if err != nil {
			log.Fatalf("Forward failed: %v", err)
		}
		// Sample with recent context for repetition penalty
		logitsSlice = getLogitsRowCPU(logits, 0)
		recentTokens := ids
		if len(ids) > 64 {
			recentTokens = ids[len(ids)-64:]
		}
		if logitsSlice != nil {
			profile.Start("Decode/Sample")
			nextToken = sampler.Sample(logitsSlice, recentTokens)
			profile.End("Decode/Sample")
		} else {
			// GPU logits: same top-k merge strategy as the prefill path above.
			gpuLogits := logits.(*cuda.Tensor)
			vocabSize := gpuLogits.Shape()[1]
			view, err := gpuLogits.ViewAt(tensor.Shape{vocabSize}, 0)
			if err != nil {
				log.Fatalf("logits view failed: %v", err)
			}
			k := *topK
			if *temperature == 0 {
				k = 1
			}
			if k <= 0 {
				// top-k disabled: copy the full row and sample on CPU.
				host := make([]float32, vocabSize)
				profile.Start("Decode/LogitsD2H")
				if err := view.CopyToHost(host); err != nil {
					log.Fatalf("logits D2H failed: %v", err)
				}
				profile.End("Decode/LogitsD2H")
				profile.Start("Decode/Sample")
				nextToken = sampler.Sample(host, recentTokens)
				profile.End("Decode/Sample")
				goto sampledDecode
			}
			repIDs := make([]int32, len(recentTokens))
			for i, t := range recentTokens {
				repIDs[i] = int32(t)
			}
			profile.Start("Decode/TopK")
			allIDs, allScores, blocks, err := cuda.TopKLogitsF32(view.Data().(unsafe.Pointer), vocabSize, repIDs, float32(*repPenalty), k, gpuLogits.GPU())
			profile.End("Decode/TopK")
			if err != nil {
				log.Fatalf("cuda topk failed: %v", err)
			}
			cands := make([]struct {
				id    int32
				score float32
			}, 0, blocks*k)
			for i := 0; i < blocks*k; i++ {
				if allIDs[i] < 0 {
					continue
				}
				cands = append(cands, struct {
					id    int32
					score float32
				}{id: allIDs[i], score: allScores[i]})
			}
			sort.Slice(cands, func(i, j int) bool { return cands[i].score > cands[j].score })
			if len(cands) > k {
				cands = cands[:k]
			}
			finalIDs := make([]int32, len(cands))
			finalScores := make([]float32, len(cands))
			for i := range cands {
				finalIDs[i] = cands[i].id
				finalScores[i] = cands[i].score
			}
			profile.Start("Decode/Sample")
			nextToken = sampler.SampleFromTopK(finalIDs, finalScores)
			profile.End("Decode/Sample")
		}
	sampledDecode:
		if tok != nil {
			fmt.Print(tok.Decode([]int{nextToken}))
		}
		ids = append(ids, nextToken)
		if pagedCache != nil {
			pagedCache.AppendToken(nextToken)
		}
		genTokens++
		profile.TokenEnd()
	}
	duration := time.Since(startGen)
	fmt.Printf("\n\nDone. Generated %d tokens in %v (%.2f tok/s)\n", genTokens, duration, float64(genTokens)/duration.Seconds())
}
  591. func createInputTensor(ids []int) tensor.Tensor {
  592. t := cpu.NewTensor(tensor.Shape{len(ids)}, nil)
  593. data := t.DataFloat32()
  594. for i, id := range ids {
  595. data[i] = float32(id)
  596. }
  597. return t
  598. }
  599. func createPositionTensor(start, count int) tensor.Tensor {
  600. t := cpu.NewTensor(tensor.Shape{count}, nil)
  601. data := t.DataFloat32()
  602. for i := 0; i < count; i++ {
  603. data[i] = float32(start + i)
  604. }
  605. return t
  606. }
  607. func getLogitsRowCPU(logits tensor.Tensor, row int) []float32 {
  608. if _, ok := logits.(*cpu.Tensor); !ok {
  609. return nil
  610. }
  611. data := logits.Data().(unsafe.Pointer)
  612. shape := logits.Shape()
  613. vocabSize := shape[1]
  614. slice := unsafe.Slice((*float32)(data), shape.NumElements())
  615. return slice[row*vocabSize : (row+1)*vocabSize]
  616. }
  617. // parseLayerMap parses a comma-separated placement string like
  618. // "0-9:gpu0,10-19:gpu1,20-:cpu" and returns per-layer placements.
  619. func parseLayerMap(numLayers int, spec string) []tensor.DevicePlacement {
  620. placements := make([]tensor.DevicePlacement, numLayers)
  621. for i := range placements {
  622. placements[i] = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
  623. }
  624. if spec == "" {
  625. return placements
  626. }
  627. entries := strings.Split(spec, ",")
  628. for _, entry := range entries {
  629. entry = strings.TrimSpace(entry)
  630. if entry == "" {
  631. continue
  632. }
  633. parts := strings.Split(entry, ":")
  634. if len(parts) != 2 {
  635. log.Printf("invalid layer-map entry %q, skipping", entry)
  636. continue
  637. }
  638. rng, target := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
  639. start, end := 0, numLayers-1
  640. if rng != "" {
  641. if strings.Contains(rng, "-") {
  642. rp := strings.SplitN(rng, "-", 2)
  643. if rp[0] != "" {
  644. if v, err := strconv.Atoi(rp[0]); err == nil {
  645. start = v
  646. }
  647. }
  648. if rp[1] != "" {
  649. if v, err := strconv.Atoi(rp[1]); err == nil {
  650. end = v
  651. }
  652. }
  653. } else if v, err := strconv.Atoi(rng); err == nil {
  654. start, end = v, v
  655. }
  656. }
  657. if start < 0 {
  658. start = 0
  659. }
  660. if end >= numLayers {
  661. end = numLayers - 1
  662. }
  663. var placement tensor.DevicePlacement
  664. switch {
  665. case strings.HasPrefix(strings.ToLower(target), "gpu"):
  666. idStr := strings.TrimPrefix(strings.ToLower(target), "gpu")
  667. id := 0
  668. if idStr != "" {
  669. if v, err := strconv.Atoi(idStr); err == nil {
  670. id = v
  671. }
  672. }
  673. placement = tensor.DevicePlacement{Type: tensor.CUDA, GPU: id}.Normalize()
  674. case strings.ToLower(target) == "cpu":
  675. placement = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
  676. default:
  677. log.Printf("unknown target %q, defaulting to CPU", target)
  678. placement = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
  679. }
  680. for i := start; i <= end && i < numLayers; i++ {
  681. placements[i] = placement
  682. }
  683. }
  684. return placements
  685. }