Initial commit

Çetin committed 3 weeks ago
commit b6c3eecc74
100 files changed: 17,219 additions and 0 deletions
  1. .gitignore (+33, -0)
  2. Makefile (+154, -0)
  3. README.md (+115, -0)
  4. cmd/convert/main.go (+107, -0)
  5. cmd/openai/main.go (+440, -0)
  6. cmd/quantize/main.go (+212, -0)
  7. cmd/run-model/main.go (+729, -0)
  8. cmd/run-model/main_test.go (+26, -0)
  9. go.mod (+8, -0)
 10. go.sum (+12, -0)
 11. pkg/backend/cpu/affinity_linux.go (+145, -0)
 12. pkg/backend/cpu/affinity_other.go (+7, -0)
 13. pkg/backend/cpu/cpufeat.go (+92, -0)
 14. pkg/backend/cpu/matmul/gemm_blocked.go (+130, -0)
 15. pkg/backend/cpu/matmul/gemv_f32_tile8_avx2.s (+202, -0)
 16. pkg/backend/cpu/matmul/gemv_f32_tile8_avx512.s (+215, -0)
 17. pkg/backend/cpu/matmul/gemv_f32_tiled_amd64.go (+109, -0)
 18. pkg/backend/cpu/matmul/gemv_f32_tiled_generic.go (+17, -0)
 19. pkg/backend/cpu/matmul/gemv_f32_tiled_test.go (+68, -0)
 20. pkg/backend/cpu/matmul/linear.go (+13, -0)
 21. pkg/backend/cpu/matmul/linear_bench_test.go (+137, -0)
 22. pkg/backend/cpu/matmul/linear_cuda.go (+71, -0)
 23. pkg/backend/cpu/matmul/linear_cuda_test.go (+50, -0)
 24. pkg/backend/cpu/matmul/linear_shared.go (+688, -0)
 25. pkg/backend/cpu/nn/activations.go (+94, -0)
 26. pkg/backend/cpu/nn/attention.go (+164, -0)
 27. pkg/backend/cpu/nn/attention_batch.go (+192, -0)
 28. pkg/backend/cpu/nn/attention_batch_test.go (+102, -0)
 29. pkg/backend/cpu/nn/attention_cached.go (+454, -0)
 30. pkg/backend/cpu/nn/attention_cached_kv.go (+108, -0)
 31. pkg/backend/cpu/nn/attention_cached_test.go (+117, -0)
 32. pkg/backend/cpu/nn/conv1d.go (+128, -0)
 33. pkg/backend/cpu/nn/embedding.go (+229, -0)
 34. pkg/backend/cpu/nn/input.go (+42, -0)
 35. pkg/backend/cpu/nn/kda.go (+114, -0)
 36. pkg/backend/cpu/nn/mlp.go (+32, -0)
 37. pkg/backend/cpu/nn/moe.go (+181, -0)
 38. pkg/backend/cpu/nn/nn_bench_test.go (+168, -0)
 39. pkg/backend/cpu/nn/nn_simd_test.go (+84, -0)
 40. pkg/backend/cpu/nn/rmsnorm.go (+32, -0)
 41. pkg/backend/cpu/nn/rope.go (+69, -0)
 42. pkg/backend/cpu/nn/self_attention.go (+97, -0)
 43. pkg/backend/cpu/nn/silu.go (+72, -0)
 44. pkg/backend/cpu/nn/silu_amd64.go (+15, -0)
 45. pkg/backend/cpu/nn/silu_avx2.s (+133, -0)
 46. pkg/backend/cpu/nn/silu_avx512.s (+87, -0)
 47. pkg/backend/cpu/nn/silu_noasm.go (+12, -0)
 48. pkg/backend/cpu/nn/softmax.go (+72, -0)
 49. pkg/backend/cpu/nn/softmax_amd64.go (+94, -0)
 50. pkg/backend/cpu/nn/softmax_avx2.s (+160, -0)
 51. pkg/backend/cpu/nn/softmax_avx512.s (+158, -0)
 52. pkg/backend/cpu/nn/softmax_noasm.go (+11, -0)
 53. pkg/backend/cpu/nn/tensor_utils.go (+21, -0)
 54. pkg/backend/cpu/nn/transformer.go (+39, -0)
 55. pkg/backend/cpu/ops/add.go (+17, -0)
 56. pkg/backend/cpu/ops/copy.go (+11, -0)
 57. pkg/backend/cpu/ops/factory.go (+42, -0)
 58. pkg/backend/cpu/ops/mul.go (+14, -0)
 59. pkg/backend/cpu/ops/permute.go (+58, -0)
 60. pkg/backend/cpu/ops/repeat.go (+95, -0)
 61. pkg/backend/cpu/ops/reshape.go (+45, -0)
 62. pkg/backend/cpu/ops/scalar.go (+49, -0)
 63. pkg/backend/cpu/ops/slice.go (+121, -0)
 64. pkg/backend/cpu/simd.go (+99, -0)
 65. pkg/backend/cpu/simd_avx2.go (+28, -0)
 66. pkg/backend/cpu/simd_avx2.s (+111, -0)
 67. pkg/backend/cpu/simd_avx512.go (+29, -0)
 68. pkg/backend/cpu/simd_avx512.s (+114, -0)
 69. pkg/backend/cpu/simd_noasm.go (+12, -0)
 70. pkg/backend/cpu/simd_random_test.go (+30, -0)
 71. pkg/backend/cpu/simd_test.go (+138, -0)
 72. pkg/backend/cpu/tensor.go (+277, -0)
 73. pkg/backend/cpu/threads.go (+36, -0)
 74. pkg/backend/cuda/attention_bench_test.go (+141, -0)
 75. pkg/backend/cuda/cuda.go (+1566, -0)
 76. pkg/backend/cuda/cuda_common.cuh (+43, -0)
 77. pkg/backend/cuda/cuda_dequant_other.cu (+129, -0)
 78. pkg/backend/cuda/cuda_dequant_q4k.cu (+50, -0)
 79. pkg/backend/cuda/cuda_dequant_q5k.cu (+54, -0)
 80. pkg/backend/cuda/cuda_dequant_q8k.cu (+24, -0)
 81. pkg/backend/cuda/cuda_elementwise.cu (+557, -0)
 82. pkg/backend/cuda/cuda_kernels_test.go (+254, -0)
 83. pkg/backend/cuda/cuda_matmul.cu (+1295, -0)
 84. pkg/backend/cuda/cuda_memory.cu (+317, -0)
 85. pkg/backend/cuda/cuda_nn.cu (+2165, -0)
 86. pkg/backend/cuda/cuda_stub.go (+221, -0)
 87. pkg/backend/cuda/dequant_test.go (+527, -0)
 88. pkg/backend/cuda/kernels.cu (+8, -0)
 89. pkg/backend/cuda/kernels.h (+344, -0)
 90. pkg/backend/device/device.go (+254, -0)
 91. pkg/chat/parse_tool_calls.go (+65, -0)
 92. pkg/chat/parse_tool_calls_test.go (+38, -0)
 93. pkg/chat/registry.go (+51, -0)
 94. pkg/chat/render_qwen3.go (+140, -0)
 95. pkg/chat/render_qwen3_test.go (+70, -0)
 96. pkg/chat/template.go (+31, -0)
 97. pkg/chat/types.go (+37, -0)
 98. pkg/compute/activation.go (+184, -0)
 99. pkg/compute/compute.go (+122, -0)
100. pkg/compute/compute_test.go (+145, -0)

+ 33 - 0
.gitignore

@@ -0,0 +1,33 @@
+# Virtual environment
+.venv/
+
+# Build artifacts
+bin/
+build/
+*.so
+*.o
+*.a
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# OS files
+.DS_Store
+Thumbs.db
+
+# Test artifacts
+*.test
+coverage.out
+
+# Large model files
+*.mak
+*.gguf
+*.safetensors
+*.bin
+tokenizer.json
+.cache
+.gocache/
+.gotmp/

+ 154 - 0
Makefile

@@ -0,0 +1,154 @@
+.PHONY: all test test-gen clean build build-cuda cuda-lib quantize
+
+# CUDA Configuration
+CUDA_HOME ?= /usr/local/cuda
+NVCC ?= $(CUDA_HOME)/bin/nvcc
+CUDA_LIB_PATH ?= $(CUDA_HOME)/lib64
+
+GO_BUILD_FLAGS ?= -trimpath -ldflags="-s -w"
+
+# Build directories
+CUDA_SRC_DIR = pkg/backend/cuda
+CUDA_BUILD_DIR = build/cuda
+CUDA_OBJ = \
+	$(CUDA_BUILD_DIR)/cuda_memory.o \
+	$(CUDA_BUILD_DIR)/cuda_elementwise.o \
+	$(CUDA_BUILD_DIR)/cuda_dequant_q8k.o \
+	$(CUDA_BUILD_DIR)/cuda_dequant_q4k.o \
+	$(CUDA_BUILD_DIR)/cuda_dequant_q5k.o \
+	$(CUDA_BUILD_DIR)/cuda_dequant_other.o \
+	$(CUDA_BUILD_DIR)/cuda_matmul.o \
+	$(CUDA_BUILD_DIR)/cuda_nn.o
+CUDA_STATIC_LIB = $(CUDA_BUILD_DIR)/libmakarna_cuda.a
+CUDA_SHARED_LIB = $(CUDA_BUILD_DIR)/libmakarna_cuda.so
+
+all: build
+
+# Build CPU-only binaries
+build:
+	go build $(GO_BUILD_FLAGS) -o bin/makarna ./cmd/run-model
+	go build $(GO_BUILD_FLAGS) -o bin/quantize ./cmd/quantize
+	go build $(GO_BUILD_FLAGS) -o bin/convert ./cmd/convert
+
+# Build CUDA-enabled binaries with static linking of our code
+build-cuda: cuda-static-lib
+	CGO_LDFLAGS="-L$(CURDIR)/$(CUDA_BUILD_DIR) -L$(CUDA_LIB_PATH) -Wl,-Bstatic -lmakarna_cuda -Wl,-Bdynamic -lcudart -lstdc++" \
+	CGO_CFLAGS="-I$(CURDIR)/$(CUDA_SRC_DIR)" \
+	go build $(GO_BUILD_FLAGS) -tags cuda -o bin/makarna-cuda ./cmd/run-model
+	go build $(GO_BUILD_FLAGS) -o bin/quantize ./cmd/quantize
+	go build $(GO_BUILD_FLAGS) -o bin/convert ./cmd/convert
+	@echo "CUDA build complete. Run with: ./bin/makarna-cuda"
+
+# Compile CUDA kernels into static library
+cuda-static-lib: $(CUDA_STATIC_LIB)
+
+$(CUDA_STATIC_LIB): $(CUDA_OBJ)
+	@echo "Creating static library..."
+	ar rcs $@ $^
+	@echo "Static library built: $@"
+
+
+$(CUDA_BUILD_DIR):
+	mkdir -p $@
+
+$(CUDA_BUILD_DIR)/%.o: $(CUDA_SRC_DIR)/%.cu $(CUDA_SRC_DIR)/kernels.h $(CUDA_SRC_DIR)/cuda_common.cuh | $(CUDA_BUILD_DIR)
+	@echo "Compiling CUDA kernels..."
+	$(NVCC) -c -Xcompiler -fPIC -Xcompiler -O3 -Xcompiler -DNDEBUG \
+		-O3 \
+		--use_fast_math \
+		--expt-relaxed-constexpr \
+		-std=c++17 \
+		-arch=sm_75 \
+		-gencode=arch=compute_75,code=sm_75 \
+		-gencode=arch=compute_80,code=sm_80 \
+		-gencode=arch=compute_86,code=sm_86 \
+		-gencode=arch=compute_89,code=sm_89 \
+		-o $@ $<
+
+# Legacy: shared library (kept for compatibility)
+cuda-lib: $(CUDA_OBJ)
+	@echo "Building CUDA shared library..."
+	$(NVCC) -shared -Xcompiler -fPIC \
+		-O3 \
+		--use_fast_math \
+		-arch=sm_75 \
+		-o $(CUDA_SHARED_LIB) $(CUDA_OBJ)
+
+run-cuda: build-cuda
+	LD_LIBRARY_PATH=$(CURDIR)/$(CUDA_BUILD_DIR):$(CUDA_LIB_PATH):$$LD_LIBRARY_PATH \
+	./bin/makarna-cuda -model $(MODEL) -prompt "$(PROMPT)" -chat -steps $(STEPS) -n-gpu-layers $(GPU_LAYERS)
+
+# Default values for run-cuda
+MODEL ?= /home/ai/llama/quants/qwen3-q8.mak
+PROMPT ?= "Hello"
+STEPS ?= 10
+GPU_LAYERS ?= 28
+
+PYTHON ?= python3
+
+test-gen:
+	@echo "Generating golden test data..."
+	PYTHONPATH=. $(PYTHON) scripts/gen_test_data.py
+	@echo "Running tests..."
+	go test -v ./tests/... ./pkg/...
+
+test-cpu:
+	@echo "Running CPU tests..."
+	go test -v ./pkg/...
+
+test-cuda: cuda-lib
+	@echo "Running CUDA tests..."
+	CGO_LDFLAGS="-L$(CURDIR)/$(CUDA_BUILD_DIR) -L$(CUDA_LIB_PATH) -Wl,-Bstatic -lmakarna_cuda -Wl,-Bdynamic -lcudart -Wl,-rpath,$(CURDIR)/$(CUDA_BUILD_DIR) -Wl,-rpath,$(CUDA_LIB_PATH)" \
+	LD_LIBRARY_PATH=$(CURDIR)/$(CUDA_BUILD_DIR):$(CUDA_LIB_PATH):$$LD_LIBRARY_PATH \
+	go test -tags cuda -v ./pkg/backend/cuda/...
+
+test-quant:
+	@echo "Testing quantization functions..."
+	go test -v ./pkg/quant/...
+
+bench-quant:
+	@echo "Benchmarking quantization..."
+	go test -bench=. ./pkg/quant/
+
+clean:
+	rm -rf bin/
+	rm -rf build/
+	rm -f tests/data/*.bin
+
+clean-cuda:
+	rm -f $(CUDA_STATIC_LIB) $(CUDA_SHARED_LIB)
+
+# Convenience targets for model conversion
+convert-f32:
+	PYTHONPATH=scripts $(PYTHON) scripts/convert_fast.py $(MODEL) $(OUTPUT)
+
+quantize-q4k:
+	./bin/quantize $(INPUT) $(OUTPUT) q4_k
+
+quantize-q6k:
+	./bin/quantize $(INPUT) $(OUTPUT) q6_k
+
+quantize-q8k:
+	./bin/quantize $(INPUT) $(OUTPUT) q8_k
+
+# Help
+help:
+	@echo "Makarna - Inference Engine"
+	@echo ""
+	@echo "Build targets:"
+	@echo "  make build        - Build CPU-only binaries"
+	@echo "  make build-cuda   - Build CUDA-enabled binaries"
+	@echo "  make cuda-lib     - Build CUDA kernel library only"
+	@echo ""
+	@echo "Run targets:"
+	@echo "  make run-cuda MODEL=path PROMPT='text' STEPS=n GPU_LAYERS=n"
+	@echo ""
+	@echo "Test targets:"
+	@echo "  make test-cpu     - Run CPU tests"
+	@echo "  make test-cuda    - Run CUDA tests"
+	@echo "  make test-quant   - Run quantization tests"
+	@echo ""
+	@echo "Clean targets:"
+	@echo "  make clean        - Remove all build artifacts"
+	@echo "  make clean-cuda   - Remove CUDA library only"

+ 115 - 0
README.md

@@ -0,0 +1,115 @@
+# Experimental Project (Archived)
+This is an experimental project and is no longer maintained.
+
+The implementation and algorithms are heavily inspired by [llama.cpp](https://github.com/ggerganov/llama.cpp) and [vLLM](https://github.com/vllm-project/vllm).
+
+# Makarna Engine
+High-performance LLM inference engine in Go, optimized with SIMD (AVX2/AVX512).
+
+## Installation
+
+Build with Makefile:
+```bash
+make build
+```
+This produces binaries in `bin/`: `makarna`, `quantize`, `convert`.
+
+Build with CUDA:
+```bash
+make build-cuda
+```
+Produces `bin/makarna-cuda`.
+
+Alternatively, use Go install:
+```bash
+go install ./cmd/...
+```
+
+## Commands
+
+### convert
+Convert HuggingFace models (.safetensors) to .mak format.
+```bash
+convert <hf_dir> <output.mak> [flags]
+```
+Flags:
+- `--quant <type>` Options: q2_k, q3_k, q4_k, q5_k, q6_k, q8_k.
+- `--mix` Enable smart mix quantization.
+- `--workers <n>` Number of parallel workers.
+- `--max-inflight-mb <n>` Memory limit during conversion.
+
+### quantize
+Quantize an existing .mak file to a K-quant format.
+```bash
+quantize [flags] <input.mak> <output.mak> <type>
+```
+Flags (must come before the positional arguments):
+- `--mix` Enable smart mix mode.
+
+### run-model
+Inference CLI.
+```bash
+run-model -model <file.mak> -prompt "text" [flags]
+```
+Common Flags:
+- `-steps <n>` Max tokens (default 10).
+- `-temp <f>` Temperature (default 0.7).
+- `-top-k <n>` Top-K (default 40).
+- `-top-p <f>` Top-P (default 0.9).
+- `-rep-penalty <f>` Repetition penalty (default 1.1).
+- `-chat` Use chat formatting.
+- `-threads <n>` CPU threads (-1 = 90% of cores).
+- `-n-gpu-layers <n>` Layers to offload to GPU (-1=auto).
+- `-gpu-budget <f>` GPU memory fraction (0.0-1.0).
+- `-mmap` Use mmap for weights.
+- `-profile-log <val>` Profile output (true, report, or <file>).
+- `-listen <addr>` Start OpenAI-compatible server on <addr>.
+
+### openai
+Dedicated OpenAI-compatible API server.
+```bash
+openai -model <file.mak> [flags]
+```
+Flags:
+- `-listen <addr>` Default is :8080.
+- `-max-seq-len <n>` Max context length.
+- `-n-gpu-layers <n>` Number of GPU layers.
+
+## Quantization Types
+MAK v2 supports K-quants (block size 256):
+- `q8_k`: 8-bit.
+- `q6_k`: 6-bit.
+- `q5_k`: 5-bit.
+- `q4_k`: 4-bit (recommended).
+- `q3_k`: 3-bit.
+- `q2_k`: 2-bit.
+
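
As a rough guide to output size: each K-quant block also carries per-block scales, so the effective rate sits slightly above the nominal bit width. The sketch below is not from this repository; the bits-per-weight figures are assumptions in the spirit of llama.cpp's K-quants, and the arithmetic is simply parameters times bits divided by 8.

```go
package main

import "fmt"

func main() {
	const params = 1.7e9 // parameter count, e.g. a 1.7B model

	// Assumed effective bits per weight (nominal width plus per-block scale overhead).
	quants := []struct {
		name string
		bpw  float64
	}{
		{"q8_k", 8.5},
		{"q6_k", 6.6},
		{"q4_k", 4.5},
		{"q2_k", 2.6},
	}
	for _, q := range quants {
		gb := params * q.bpw / 8 / 1e9
		fmt.Printf("%s: ~%.2f GB\n", q.name, gb)
	}
}
```

`quantize` prints per-tensor ratios and the overall compression factor when it finishes, so the real numbers are easy to check against this estimate.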
+## Examples
+
+Convert and quantize:
+```bash
+convert /models/Qwen3-1.7B-Instruct model-q4k.mak --quant q4_k --mix
+```
+
+Run inference:
+```bash
+run-model -model model-q4k.mak -prompt "Explain quantum physics" -steps 100
+```
+
+Start API server:
+```bash
+run-model -model model-q4k.mak -listen :8080 -chat
+```
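
Once the server is up, any OpenAI-style client can call it. The following is a minimal Go client sketch (not part of this project); it assumes the `127.0.0.1:8080` address from the example above and mirrors the request/response fields handled in `cmd/openai/main.go` (`model`, `messages`, `max_tokens`, `temperature`, and `choices[].message.content` in the reply). Streaming is not implemented server-side, so the client reads a single JSON response.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

type chatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type chatRequest struct {
	Model       string        `json:"model"`
	Messages    []chatMessage `json:"messages"`
	MaxTokens   int           `json:"max_tokens,omitempty"`
	Temperature float64       `json:"temperature,omitempty"`
}

type chatResponse struct {
	Choices []struct {
		Message struct {
			Role    string `json:"role"`
			Content string `json:"content"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
}

func main() {
	// Build a single-turn chat completion request.
	body, _ := json.Marshal(chatRequest{
		Model:       "local",
		Messages:    []chatMessage{{Role: "user", Content: "Hello"}},
		MaxTokens:   64,
		Temperature: 0.7,
	})
	resp, err := http.Post("http://127.0.0.1:8080/v1/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	var out chatResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		log.Fatal(err)
	}
	if len(out.Choices) > 0 {
		fmt.Println(out.Choices[0].Message.Content)
	}
}
```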
+
+## Development
+
+Tests:
+```bash
+go test ./...
+go test -tags cuda ./... # Requires GPU
+```
+
+Benchmarks:
+```bash
+go test -bench=. ./pkg/tensor/...
+```

+ 107 - 0
cmd/convert/main.go

@@ -0,0 +1,107 @@
+// Command convert converts HuggingFace models (safetensors) to .mak format
+//
+// Pure Go implementation - no Python dependencies!
+//
+// Usage:
+//
+//	convert /path/to/model output.mak
+//	convert /path/to/model output.mak --quant q4_k
+//	convert /path/to/model output.mak --quant q4_k --mix
+package main
+
+import (
+	"flag"
+	"fmt"
+	"os"
+	"strings"
+
+	"makarna/pkg/convert"
+	_ "makarna/pkg/model/models"
+	"makarna/pkg/quant"
+)
+
+func main() {
+	fs := flag.NewFlagSet(os.Args[0], flag.ExitOnError)
+	quantType := fs.String("quant", "", "Quantization type: q2_k, q3_k, q4_k, q5_k, q6_k, q8_k (empty = F32)")
+	mixMode := fs.Bool("mix", false, "Enable smart mix quantization")
+	workers := fs.Int("workers", 0, "Number of parallel workers for conversion (0 = GOMAXPROCS)")
+	maxInFlightMB := fs.Int("max-inflight-mb", 0, "Max in-flight tensor memory (decoded+output) in MB (0 = auto)")
+
+	// Go's standard flag package stops parsing at the first non-flag argument.
+	// Users often pass: convert <model_dir> <out.mak> --quant q4_k --mix
+	// Pre-scan to allow flags to appear anywhere.
+	flagArgs, posArgs := splitFlagsAndPositionals(os.Args[1:])
+	_ = fs.Parse(flagArgs)
+
+	args := posArgs
+	if len(args) < 2 {
+		fmt.Println("Usage: convert <model_dir> <output.mak> [--quant TYPE] [--mix]")
+		os.Exit(1)
+	}
+
+	modelPath := args[0]
+	outputPath := args[1]
+
+	var baseQuant quant.QuantType
+	if *quantType != "" {
+		baseQuant = quant.QuantType(*quantType)
+		switch baseQuant {
+		case quant.TypeQ2K, quant.TypeQ3K, quant.TypeQ4K, quant.TypeQ5K, quant.TypeQ6K, quant.TypeQ8K:
+			// OK
+		default:
+			fmt.Printf("Unknown quant type: %s\n", *quantType)
+			os.Exit(1)
+		}
+	}
+
+	fmt.Printf("convert: model=%s out=%s quant=%s mix=%v workers=%d max_inflight_mb=%d\n", modelPath, outputPath, baseQuant, *mixMode, *workers, *maxInFlightMB)
+
+	opts := convert.Options{
+		BaseQuant:        baseQuant,
+		MixMode:          *mixMode,
+		Workers:          *workers,
+		MaxInFlightBytes: uint64(*maxInFlightMB) * 1024 * 1024,
+	}
+	if err := convert.ConvertDirectory(modelPath, outputPath, opts); err != nil {
+		fmt.Printf("convert failed: %v\n", err)
+		os.Exit(1)
+	}
+}
+
+func splitFlagsAndPositionals(argv []string) (flagArgs []string, posArgs []string) {
+	// Known flags for this command. Values indicate whether the flag expects a separate value.
+	expectsValue := map[string]bool{
+		"--quant":           true,
+		"-quant":            true,
+		"--workers":         true,
+		"-workers":          true,
+		"--max-inflight-mb": true,
+		"-max-inflight-mb":  true,
+		"--mix":             false,
+		"-mix":              false,
+	}
+
+	for i := 0; i < len(argv); i++ {
+		a := argv[i]
+		if !strings.HasPrefix(a, "-") {
+			posArgs = append(posArgs, a)
+			continue
+		}
+
+		// Support --flag=value form.
+		if strings.Contains(a, "=") {
+			flagArgs = append(flagArgs, a)
+			continue
+		}
+
+		flagArgs = append(flagArgs, a)
+		if expectsValue[a] {
+			if i+1 < len(argv) && !strings.HasPrefix(argv[i+1], "-") {
+				flagArgs = append(flagArgs, argv[i+1])
+				i++
+			}
+		}
+	}
+
+	return flagArgs, posArgs
+}
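
Because Go's `flag` package stops parsing at the first non-flag argument, the pre-scan in `splitFlagsAndPositionals` is what lets `convert <dir> <out.mak> --quant q4_k --mix` work with trailing flags. A companion test sketch (not part of the commit; it assumes a file alongside `main.go` in the same `package main`) illustrates the expected split:

```go
package main

import (
	"reflect"
	"testing"
)

// TestSplitFlagsAndPositionals checks that flags trailing the positional
// <model_dir> and <output.mak> arguments are still routed to the flag set.
func TestSplitFlagsAndPositionals(t *testing.T) {
	flags, pos := splitFlagsAndPositionals([]string{
		"/models/Qwen3-1.7B-Instruct", "out.mak", "--quant", "q4_k", "--mix",
	})
	if !reflect.DeepEqual(pos, []string{"/models/Qwen3-1.7B-Instruct", "out.mak"}) {
		t.Fatalf("positionals = %v", pos)
	}
	if !reflect.DeepEqual(flags, []string{"--quant", "q4_k", "--mix"}) {
		t.Fatalf("flags = %v", flags)
	}
}
```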

+ 440 - 0
cmd/openai/main.go

@@ -0,0 +1,440 @@
+package main
+
+import (
+	"context"
+	"encoding/json"
+	"flag"
+	"fmt"
+	"log"
+	"math/rand"
+	"net/http"
+	"strings"
+	"time"
+	"unsafe"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/backend/cuda"
+	"makarna/pkg/backend/device"
+	"makarna/pkg/chat"
+	"makarna/pkg/engine"
+	"makarna/pkg/kvcache"
+	"makarna/pkg/sample"
+	"makarna/pkg/tensor"
+	"makarna/pkg/tokenizer"
+)
+
+type chatCompletionRequest struct {
+	Model            string                  `json:"model"`
+	Messages         []chatCompletionMessage `json:"messages"`
+	Tools            []any                   `json:"tools,omitempty"`
+	Stream           bool                    `json:"stream,omitempty"`
+	MaxTokens        int                     `json:"max_tokens,omitempty"`
+	Temperature      float64                 `json:"temperature,omitempty"`
+	TopP             float64                 `json:"top_p,omitempty"`
+	TopK             int                     `json:"top_k,omitempty"`
+	PresencePenalty  float64                 `json:"presence_penalty,omitempty"`
+	FrequencyPenalty float64                 `json:"frequency_penalty,omitempty"`
+}
+
+type chatCompletionMessage struct {
+	Role       string `json:"role"`
+	Content    string `json:"content"`
+	Name       string `json:"name,omitempty"`
+	ToolCallID string `json:"tool_call_id,omitempty"`
+}
+
+type chatCompletionResponse struct {
+	ID      string                 `json:"id"`
+	Object  string                 `json:"object"`
+	Created int64                  `json:"created"`
+	Model   string                 `json:"model"`
+	Usage   chatCompletionUsage    `json:"usage"`
+	Choices []chatCompletionChoice `json:"choices"`
+}
+
+type chatCompletionUsage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}
+
+type chatCompletionChoice struct {
+	Index        int                  `json:"index"`
+	Message      chatCompletionOutMsg `json:"message"`
+	FinishReason string               `json:"finish_reason"`
+}
+
+type chatCompletionOutMsg struct {
+	Role      string           `json:"role"`
+	Content   string           `json:"content"`
+	ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
+}
+
+type openAIToolCall struct {
+	ID       string             `json:"id"`
+	Type     string             `json:"type"`
+	Function openAIFunctionCall `json:"function"`
+}
+
+type openAIFunctionCall struct {
+	Name      string `json:"name"`
+	Arguments string `json:"arguments"`
+}
+
+type server struct {
+	eng       *engine.Engine
+	tok       *tokenizer.Tokenizer
+	arch      string
+	maxSeqLen int
+	blockSize int
+}
+
+func main() {
+	listen := flag.String("listen", ":8080", "listen address")
+	modelPath := flag.String("model", "model.mak", "Path to .mak model file")
+	maxSeq := flag.Int("max-seq-len", 8192, "Maximum sequence length to reserve in KV cache")
+	blockSize := flag.Int("block-size", 32, "KV cache block size")
+
+	nGPULayers := flag.Int("n-gpu-layers", -1, "Number of layers to offload to GPU (-1=auto, 0=CPU only)")
+	gpuBudget := flag.Float64("gpu-budget", 0.9, "Fraction of GPU memory to use (0.0-1.0)")
+	flag.Parse()
+
+	cfg := engine.Config{GPULayers: *nGPULayers, GPUBudget: *gpuBudget}
+	eng, err := engine.Load(*modelPath, cfg)
+	if err != nil {
+		log.Fatalf("load model: %v", err)
+	}
+	defer eng.Close()
+
+	md := eng.Model().Config()
+
+	var tok *tokenizer.Tokenizer
+	tokData, err := eng.Loader().GetTokenizerData()
+	if err == nil && len(tokData) > 0 {
+		tok, err = tokenizer.LoadFromBytes(tokData)
+		if err != nil {
+			log.Printf("warning: load embedded tokenizer: %v", err)
+		}
+	}
+	if tok == nil {
+		log.Fatalf("tokenizer not found in model file")
+	}
+
+	s := &server{eng: eng, tok: tok, arch: md.Architecture, maxSeqLen: *maxSeq, blockSize: *blockSize}
+
+	h := http.NewServeMux()
+	h.HandleFunc("/v1/chat/completions", s.handleChatCompletions)
+	h.HandleFunc("/v1/models", s.handleModels)
+
+	log.Printf("listening on %s (arch=%s, cuda=%v)", *listen, s.arch, device.CUDAAvailable())
+	log.Fatal(http.ListenAndServe(*listen, h))
+}
+
+func (s *server) handleModels(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodGet {
+		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+		return
+	}
+	resp := map[string]any{
+		"object": "list",
+		"data":   []any{map[string]any{"id": "local", "object": "model"}},
+	}
+	writeJSON(w, resp)
+}
+
+func (s *server) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodPost {
+		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+		return
+	}
+	var req chatCompletionRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, "bad json", http.StatusBadRequest)
+		return
+	}
+	if req.Stream {
+		http.Error(w, "stream not implemented", http.StatusNotImplemented)
+		return
+	}
+	if len(req.Messages) == 0 {
+		http.Error(w, "messages required", http.StatusBadRequest)
+		return
+	}
+
+	msgs := make([]chat.Message, 0, len(req.Messages))
+	for _, m := range req.Messages {
+		role := strings.ToLower(m.Role)
+		msgs = append(msgs, chat.Message{Role: role, Content: m.Content})
+	}
+
+	prompt, err := chat.RenderForArchitecture(s.arch, msgs, chat.Options{
+		AddGenerationPrompt: true,
+		EnableThinking:      true,
+		Tools:               req.Tools,
+	})
+	if err != nil {
+		http.Error(w, fmt.Sprintf("render prompt: %v", err), http.StatusInternalServerError)
+		return
+	}
+
+	promptTokens := len(s.tok.Encode(prompt))
+
+	maxTokens := req.MaxTokens
+	if maxTokens <= 0 {
+		maxTokens = 128
+	}
+	temp := req.Temperature
+	if temp == 0 {
+		temp = 0.7
+	}
+	topP := req.TopP
+	if topP == 0 {
+		topP = 0.9
+	}
+	topK := req.TopK
+	if topK == 0 {
+		topK = 40
+	}
+
+	outText, err := s.generate(r.Context(), prompt, maxTokens, temp, topP, topK)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("generate: %v", err), http.StatusInternalServerError)
+		return
+	}
+	completionTokens := len(s.tok.Encode(outText))
+
+	_, content := chat.StripThinking(outText)
+	content, calls, err := chat.ExtractToolCalls(content)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("parse tool_calls: %v", err), http.StatusInternalServerError)
+		return
+	}
+	content = strings.TrimSpace(content)
+
+	var outCalls []openAIToolCall
+	for i, c := range calls {
+		outCalls = append(outCalls, openAIToolCall{
+			ID:   fmt.Sprintf("call_%d_%d", time.Now().UnixNano(), i),
+			Type: "function",
+			Function: openAIFunctionCall{
+				Name:      c.Name,
+				Arguments: string(bytesOrEmptyObject(c.Arguments)),
+			},
+		})
+	}
+
+	finish := "stop"
+	if len(outCalls) > 0 {
+		finish = "tool_calls"
+	}
+
+	resp := chatCompletionResponse{
+		ID:      fmt.Sprintf("chatcmpl_%d", rand.Int63()),
+		Object:  "chat.completion",
+		Created: time.Now().Unix(),
+		Model:   req.Model,
+		Usage: chatCompletionUsage{
+			PromptTokens:     promptTokens,
+			CompletionTokens: completionTokens,
+			TotalTokens:      promptTokens + completionTokens,
+		},
+		Choices: []chatCompletionChoice{{
+			Index: 0,
+			Message: chatCompletionOutMsg{
+				Role:      "assistant",
+				Content:   content,
+				ToolCalls: outCalls,
+			},
+			FinishReason: finish,
+		}},
+	}
+	writeJSON(w, resp)
+}
+
+func bytesOrEmptyObject(b []byte) []byte {
+	if len(b) == 0 {
+		return []byte("{}")
+	}
+	return b
+}
+
+func writeJSON(w http.ResponseWriter, v any) {
+	w.Header().Set("Content-Type", "application/json")
+	enc := json.NewEncoder(w)
+	enc.SetEscapeHTML(false)
+	_ = enc.Encode(v)
+}
+
+func (s *server) generate(ctx context.Context, prompt string, maxTokens int, temperature float64, topP float64, topK int) (string, error) {
+	ids := s.tok.Encode(prompt)
+	if len(ids) == 0 {
+		return "", fmt.Errorf("empty prompt after tokenization")
+	}
+
+	modelCfg := s.eng.Model().Config()
+	placements := make([]tensor.DevicePlacement, modelCfg.NumLayers)
+	if s.eng.Dispatcher() != nil {
+		for i := 0; i < modelCfg.NumLayers; i++ {
+			placements[i] = s.eng.Dispatcher().LayerPlacement(i)
+		}
+	}
+
+	// Enable mixed per-layer KV cache when any layer is on GPU.
+	kvDevice := tensor.CPU
+	if device.CUDAAvailable() {
+		for i := 0; i < modelCfg.NumLayers && i < len(placements); i++ {
+			if placements[i].Normalize().Type == tensor.CUDA {
+				kvDevice = tensor.CUDA
+				break
+			}
+		}
+	}
+	pool, err := kvcache.NewBlockPool(kvcache.BlockPoolConfig{
+		NumLayers:  modelCfg.NumLayers,
+		NumKVHeads: modelCfg.NumKVHeads,
+		HeadDim:    modelCfg.HeadDim,
+		BlockSize:  s.blockSize,
+		NumBlocks:  (s.maxSeqLen + s.blockSize - 1) / s.blockSize,
+		Device:     kvDevice,
+		GPU:        0,
+		LayerPlacements: func() []tensor.DevicePlacement {
+			if kvDevice != tensor.CUDA || len(placements) != modelCfg.NumLayers {
+				return nil
+			}
+			out := make([]tensor.DevicePlacement, modelCfg.NumLayers)
+			for i := 0; i < modelCfg.NumLayers; i++ {
+				out[i] = placements[i].Normalize()
+			}
+			return out
+		}(),
+		Preallocate: kvDevice == tensor.CUDA,
+	})
+	if err != nil {
+		return "", err
+	}
+	cache := kvcache.NewPagedKVCache(pool, kvcache.PagedCacheConfig{
+		NumLayers:  modelCfg.NumLayers,
+		NumKVHeads: modelCfg.NumKVHeads,
+		HeadDim:    modelCfg.HeadDim,
+		BlockSize:  s.blockSize,
+		MaxSeqLen:  s.maxSeqLen,
+		Device:     kvDevice,
+		GPU:        0,
+	}, "cmd-openai")
+	if _, err := cache.AllocateForTokens(ids); err != nil {
+		cache.Free()
+		return "", err
+	}
+	defer cache.Free()
+
+	sampler := sample.New(sample.Config{
+		Temperature:       float32(temperature),
+		TopK:              topK,
+		TopP:              float32(topP),
+		RepetitionPenalty: 1.1,
+		Seed:              -1,
+	})
+
+	input := createInputTensor(ids)
+	positions := createPositionTensor(0, len(ids))
+	logits, err := s.eng.Forward(ctx, input, positions, cache)
+	if err != nil {
+		return "", err
+	}
+
+	// sample first token
+	var nextToken int
+	if logitsCPU := getLogitsRowCPU(logits, len(ids)-1); logitsCPU != nil {
+		nextToken = sampler.Sample(logitsCPU, ids)
+	} else {
+		gpuLogits := logits.(*cuda.Tensor)
+		vocabSize := gpuLogits.Shape()[1]
+		row := len(ids) - 1
+		view, err := gpuLogits.ViewAt(tensor.Shape{vocabSize}, uintptr(row*vocabSize*4))
+		if err != nil {
+			return "", err
+		}
+		host := make([]float32, vocabSize)
+		if err := view.CopyToHost(host); err != nil {
+			return "", err
+		}
+		nextToken = sampler.Sample(host, ids)
+	}
+
+	ids = append(ids, nextToken)
+	var sb strings.Builder
+	sb.WriteString(s.tok.Decode([]int{nextToken}))
+
+	eosID := s.tok.EosID()
+	for i := 1; i < maxTokens; i++ {
+		if nextToken == eosID {
+			break
+		}
+		select {
+		case <-ctx.Done():
+			return "", ctx.Err()
+		default:
+		}
+		input = createInputTensor([]int{nextToken})
+		currentPos := len(ids) - 1
+		positions = createPositionTensor(currentPos, 1)
+		logits, err = s.eng.Forward(ctx, input, positions, cache)
+		if err != nil {
+			return "", err
+		}
+
+		recent := ids
+		if len(recent) > 64 {
+			recent = recent[len(recent)-64:]
+		}
+		if logitsCPU := getLogitsRowCPU(logits, 0); logitsCPU != nil {
+			nextToken = sampler.Sample(logitsCPU, recent)
+		} else {
+			gpuLogits := logits.(*cuda.Tensor)
+			vocabSize := gpuLogits.Shape()[1]
+			view, err := gpuLogits.ViewAt(tensor.Shape{vocabSize}, 0)
+			if err != nil {
+				return "", err
+			}
+			host := make([]float32, vocabSize)
+			if err := view.CopyToHost(host); err != nil {
+				return "", err
+			}
+			nextToken = sampler.Sample(host, recent)
+		}
+
+		ids = append(ids, nextToken)
+		sb.WriteString(s.tok.Decode([]int{nextToken}))
+	}
+
+	return sb.String(), nil
+}
+
+func createInputTensor(ids []int) tensor.Tensor {
+	t := cpu.NewTensor(tensor.Shape{len(ids)}, nil)
+	data := t.DataFloat32()
+	for i, id := range ids {
+		data[i] = float32(id)
+	}
+	return t
+}
+
+func createPositionTensor(start, count int) tensor.Tensor {
+	t := cpu.NewTensor(tensor.Shape{count}, nil)
+	data := t.DataFloat32()
+	for i := 0; i < count; i++ {
+		data[i] = float32(start + i)
+	}
+	return t
+}
+
+func getLogitsRowCPU(logits tensor.Tensor, row int) []float32 {
+	if _, ok := logits.(*cpu.Tensor); !ok {
+		return nil
+	}
+	data := logits.Data().(unsafe.Pointer)
+	shape := logits.Shape()
+	vocabSize := shape[1]
+	slice := unsafe.Slice((*float32)(data), shape.NumElements())
+	return slice[row*vocabSize : (row+1)*vocabSize]
+}

+ 212 - 0
cmd/quantize/main.go

@@ -0,0 +1,212 @@
+// Command quantize converts F32/F16 .mak files to quantized versions
+//
+// Usage:
+//
+//	quantize input.mak output.mak q4_k
+//	quantize --mix input.mak output.mak q4_k  (smart mix quantization; flags must precede the positional arguments)
+package main
+
+import (
+	"encoding/binary"
+	"flag"
+	"fmt"
+	"math"
+	"os"
+	"time"
+
+	"makarna/pkg/convert"
+	"makarna/pkg/loader"
+	"makarna/pkg/quant"
+	_ "makarna/pkg/model/models"
+)
+
+func main() {
+	// Parse flags
+	mixMode := flag.Bool("mix", false, "Enable smart mix quantization (uses architecture-specific rules)")
+	flag.Parse()
+
+	args := flag.Args()
+	if len(args) < 3 {
+		fmt.Println("Usage: quantize [--mix] <input.mak> <output.mak> <quant_type>")
+		fmt.Println("Quant types: q2_k, q3_k, q4_k, q5_k, q6_k, q8_k")
+		fmt.Println("Flags:")
+		fmt.Println("  --mix    Enable smart mix quantization (must come before the positional arguments)")
+		os.Exit(1)
+	}
+
+	inputPath := args[0]
+	outputPath := args[1]
+	quantTypeStr := args[2]
+
+	baseQuant := quant.QuantType(quantTypeStr)
+	
+	// Validate quant type
+	switch baseQuant {
+	case quant.TypeQ2K, quant.TypeQ3K, quant.TypeQ4K, quant.TypeQ5K, quant.TypeQ6K, quant.TypeQ8K:
+		// OK
+	default:
+		fmt.Printf("Unknown quant type: %s\n", quantTypeStr)
+		os.Exit(1)
+	}
+
+	fmt.Printf("Loading %s...\n", inputPath)
+	startTime := time.Now()
+
+	model, err := loader.Load(inputPath)
+	if err != nil {
+		fmt.Printf("Error loading model: %v\n", err)
+		os.Exit(1)
+	}
+	defer model.Close()
+
+	architecture := model.Metadata.ModelConfig.Architecture
+	fmt.Printf("Loaded in %v\n", time.Since(startTime))
+	fmt.Printf("Architecture: %s\n", architecture)
+	
+	// Build a policy spec (model plugin may override mix behavior)
+	tieWordEmbeddings := false
+	if tie, ok := model.Metadata.ModelConfig.Params["tie_word_embeddings"].(bool); ok {
+		tieWordEmbeddings = tie
+	}
+	spec := convert.NewSpec(architecture, tieWordEmbeddings, baseQuant, *mixMode)
+
+	// Create output writer
+	writer, err := loader.NewWriter(outputPath)
+	if err != nil {
+		fmt.Printf("Error creating output: %v\n", err)
+		os.Exit(1)
+	}
+
+	writer.SetModelConfig(model.Metadata.ModelConfig)
+
+	// Copy tokenizer if present
+	tokData, err := model.GetTokenizerData()
+	if err == nil && len(tokData) > 0 {
+		writer.AddTokenizer(tokData)
+		fmt.Println("Copying embedded tokenizer...")
+	}
+
+	// Process tensors
+	totalTensors := len(model.Metadata.Tensors)
+	stats := make(map[quant.QuantType]int)
+	skipped := 0
+
+	fmt.Printf("\nQuantizing %d tensors...\n\n", totalTensors)
+
+	for name, info := range model.Metadata.Tensors {
+		data, err := model.GetTensorData(name)
+		if err != nil {
+			fmt.Printf("Error reading tensor %s: %v\n", name, err)
+			continue
+		}
+
+		// Determine if quantizable
+		nDims := len(info.Shape)
+		isQuantizable := nDims >= 2 && info.DType == loader.F32
+		
+		// Check divisibility by 256
+		if isQuantizable && info.Shape[len(info.Shape)-1]%256 != 0 {
+			isQuantizable = false
+		}
+
+		var outData []byte
+		var outDType loader.DType
+		
+		if isQuantizable {
+			// Resolve the per-tensor quant type (mix rules may override the base choice)
+			tensorQuant := spec.ResolveQuant(name, baseQuant)
+			
+			// Keep F32/F16 tensors unquantized (copied through as-is)
+			if tensorQuant == quant.TypeF32 || tensorQuant == quant.TypeF16 {
+				outData = data
+				outDType = tensorQuant.ToDType()
+				stats[tensorQuant]++
+				fmt.Printf("  %s: %v [%s] (preserved)\n", name, info.Shape, tensorQuant)
+			} else {
+				// Convert bytes to float32
+				floats := bytesToFloat32(data)
+				
+				start := time.Now()
+				
+				switch tensorQuant {
+				case quant.TypeQ8K:
+					outData = quant.QuantizeQ8K(floats)
+				case quant.TypeQ5K:
+					outData = quant.QuantizeQ5K(floats)
+				case quant.TypeQ6K:
+					outData = quant.QuantizeQ6K(floats)
+				case quant.TypeQ4K:
+					outData = quant.QuantizeQ4K(floats)
+				case quant.TypeQ3K:
+					outData = quant.QuantizeQ3K(floats)
+				case quant.TypeQ2K:
+					outData = quant.QuantizeQ2K(floats)
+				default:
+					outData = quant.QuantizeQ4K(floats)
+					tensorQuant = quant.TypeQ4K
+				}
+				
+				elapsed := time.Since(start)
+				outDType = tensorQuant.ToDType()
+				stats[tensorQuant]++
+				
+				ratio := float64(len(data)) / float64(len(outData))
+				
+				// Show mix info if different from base
+				mixInfo := ""
+				if *mixMode && tensorQuant != baseQuant {
+					mixInfo = fmt.Sprintf(" (mix: %s→%s)", baseQuant, tensorQuant)
+				}
+				
+				fmt.Printf("  %s: %v → %s (%.2fx, %v)%s\n", 
+					name, info.Shape, tensorQuant, ratio, elapsed, mixInfo)
+			}
+		} else {
+			// Keep as-is
+			outData = data
+			outDType = info.DType
+			skipped++
+		}
+
+		// Convert shape to uint64
+		shape := make([]uint64, len(info.Shape))
+		for i, s := range info.Shape {
+			shape[i] = s
+		}
+
+		if err := writer.AddTensor(name, outDType, shape, outData); err != nil {
+			fmt.Printf("Error writing tensor %s: %v\n", name, err)
+		}
+	}
+
+	if err := writer.Close(); err != nil {
+		fmt.Printf("Error closing output: %v\n", err)
+		os.Exit(1)
+	}
+
+	// Get file sizes
+	inStat, _ := os.Stat(inputPath)
+	outStat, _ := os.Stat(outputPath)
+
+	fmt.Printf("\n✓ Done!\n")
+	fmt.Printf("  Quantization breakdown:\n")
+	for qt, count := range stats {
+		fmt.Printf("    %s: %d tensors\n", qt, count)
+	}
+	fmt.Printf("  Skipped: %d tensors\n", skipped)
+	fmt.Printf("  Input size: %.2f MB\n", float64(inStat.Size())/(1024*1024))
+	fmt.Printf("  Output size: %.2f MB\n", float64(outStat.Size())/(1024*1024))
+	fmt.Printf("  Compression: %.2fx\n", float64(inStat.Size())/float64(outStat.Size()))
+	fmt.Printf("  Total time: %v\n", time.Since(startTime))
+}
+
+func bytesToFloat32(data []byte) []float32 {
+	n := len(data) / 4
+	result := make([]float32, n)
+	for i := 0; i < n; i++ {
+		bits := binary.LittleEndian.Uint32(data[i*4 : i*4+4])
+		result[i] = math.Float32frombits(bits)
+	}
+	return result
+}

+ 729 - 0
cmd/run-model/main.go

@@ -0,0 +1,729 @@
+package main
+
+import (
+	"context"
+	"encoding/json"
+	"flag"
+	"fmt"
+	"log"
+	"net"
+	"path/filepath"
+	"sort"
+	"strconv"
+	"strings"
+	"time"
+	"unsafe"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/backend/cuda"
+	"makarna/pkg/backend/device"
+	"makarna/pkg/chat"
+	"makarna/pkg/compute"
+	"makarna/pkg/engine"
+	"makarna/pkg/kvcache"
+	"makarna/pkg/loader"
+	"makarna/pkg/openai"
+	"makarna/pkg/profile"
+	"makarna/pkg/sample"
+	"makarna/pkg/model"
+	"makarna/pkg/tensor"
+	"makarna/pkg/tokenizer"
+
+	kimi_linear "makarna/pkg/model/models/kimi_linear" // Register KimiLinear model
+	_ "makarna/pkg/model/models/qwen3" // Register Qwen3 model
+)
+
+func main() {
+	modelPath := flag.String("model", "model.mak", "Path to .mak model file")
+	prompt := flag.String("prompt", "Hello world", "Prompt to generate")
+	steps := flag.Int("steps", 10, "Number of tokens to generate")
+	useChat := flag.Bool("chat", false, "Use chat format for prompt")
+	serverMode := flag.Bool("server", false, "Run OpenAI-compatible HTTP server")
+	listen := flag.String("listen", "", "Server listen address (e.g. :8080, 0.0.0.0:8080). If set, implies --server")
+	host := flag.String("host", "127.0.0.1", "Server host (used when --listen is empty)")
+	port := flag.Int("port", 8080, "Server port (used when --listen is empty)")
+	temperature := flag.Float64("temp", 0.7, "Sampling temperature (0 = greedy)")
+	topK := flag.Int("top-k", 40, "Top-K sampling (0 = disabled)")
+	topP := flag.Float64("top-p", 0.9, "Top-P nucleus sampling (1.0 = disabled)")
+	repPenalty := flag.Float64("rep-penalty", 1.1, "Repetition penalty (1.0 = disabled)")
+	threads := flag.Int("threads", -1, "Number of CPU threads to use (default: 90% of cores)")
+	listTensors := flag.Bool("list-tensors", false, "List tensors in model and exit")
+	useMmap := flag.Bool("mmap", false, "Use mmap for model weights (default: false)")
+
+	// Device placement flags - llama.cpp style
+	nGPULayers := flag.Int("n-gpu-layers", -1, "Number of layers to offload to GPU (-1=auto, 0=CPU only)")
+	gpuBudget := flag.Float64("gpu-budget", 0.9, "Fraction of GPU memory to use (0.0-1.0)")
+	gpuDevicesFlag := flag.String("gpu-devices", "", "Comma-separated GPU device ordinals to use (e.g. 0 or 0,1)")
+	layerMap := flag.String("layer-map", "", "Advanced: layer placement map, e.g. 0-9:gpu0,10-19:gpu1,20-:cpu")
+	cpuMoE := flag.Bool("cpu-moe", false, "Keep MoE expert weights on CPU (saves GPU memory for large MoE models)")
+
+	blockSize := flag.Int("block-size", 16, "KV cache block size")
+	maxSeq := flag.Int("max-seq-len", 2048, "Maximum sequence length to reserve in KV cache")
+	kvCacheCPU := flag.Bool("kv-cache-cpu", false, "Force KV cache to CPU (default: GPU when available)")
+	prefillChunkSize := flag.Int("prefill-chunk-size", 512, "Prompt prefill chunk size (llama.cpp eval batch size analogue)")
+	maxConcurrent := flag.Int("max-concurrent", 1, "Server mode: max concurrent sequences to reserve KV/scratch for")
+
+	// Profiling flags
+	profileOn := flag.Bool("profile", false, "Enable profiling summary report (alias for -profile-log=report)")
+	profileLog := flag.String("profile-log", "", "Enable profiling: 'true'=realtime screen, 'report'=summary only, or file path")
+
+	flag.Parse()
+
+	// Initialize profiling
+	if *profileOn && *profileLog == "" {
+		*profileLog = "report"
+	}
+	if *profileLog != "" {
+		profile.Enable()
+		switch strings.ToLower(*profileLog) {
+		case "true", "1", "realtime":
+			// Realtime output to stderr
+			profile.SetRealtime(true)
+			fmt.Println("Profiling enabled: realtime output to stderr")
+		case "report", "summary":
+			// Summary only at the end
+			profile.SetRealtime(false)
+			fmt.Println("Profiling enabled: summary report at end")
+		default:
+			// File path specified
+			if err := profile.SetLogFile(*profileLog); err != nil {
+				log.Fatalf("Failed to open profile log file: %v", err)
+			}
+			profile.SetRealtime(true)
+			fmt.Printf("Profiling enabled: logging to %s\n", *profileLog)
+		}
+		defer func() {
+			profile.Report()
+			profile.Close()
+		}()
+	}
+
+	cpu.SetMaxThreads(*threads)
+
+	var gpuDevices []int
+	if strings.TrimSpace(*gpuDevicesFlag) != "" {
+		for _, part := range strings.Split(*gpuDevicesFlag, ",") {
+			part = strings.TrimSpace(part)
+			if part == "" {
+				continue
+			}
+			id, err := strconv.Atoi(part)
+			if err != nil {
+				log.Fatalf("invalid --gpu-devices entry %q: %v", part, err)
+			}
+			gpuDevices = append(gpuDevices, id)
+		}
+		if len(gpuDevices) == 0 {
+			log.Fatalf("invalid --gpu-devices: no devices parsed")
+		}
+	}
+
+	// Determine engine config
+	cfg := engine.Config{
+		GPULayers:  *nGPULayers,
+		GPUBudget:  *gpuBudget,
+		GPUDevices: gpuDevices,
+		UseMmap:    *useMmap,
+		CPUMoE:     *cpuMoE,
+	}
+
+	// Load Model
+	fmt.Printf("Loading model from %s...\n", *modelPath)
+	if *listTensors {
+		md, err := loader.LoadWithOptions(*modelPath, loader.LoadOptions{UseMmap: *useMmap})
+		if err != nil {
+			log.Fatalf("Failed to load model: %v", err)
+		}
+		defer md.Close()
+		names := make([]string, 0, len(md.Metadata.Tensors))
+		for name := range md.Metadata.Tensors {
+			names = append(names, name)
+		}
+		sort.Strings(names)
+		for _, name := range names {
+			info := md.Metadata.Tensors[name]
+			fmt.Printf("%s\t%s\t%v\t%d\n", name, info.DType.String(), info.Shape, info.Size)
+		}
+		return
+	}
+	eng, err := engine.Load(*modelPath, cfg)
+	if err != nil {
+		log.Fatalf("Failed to load model: %v", err)
+	}
+	defer eng.Close()
+
+	// Show device info
+	if device.CUDAAvailable() {
+		fmt.Println("CUDA available: yes")
+	} else {
+		fmt.Println("CUDA available: no (CPU only)")
+	}
+
+	modelConfig := eng.Model().Config()
+
+	// If layer-map is specified, use that for KV cache placement
+	var placements []tensor.DevicePlacement
+	if *layerMap != "" {
+		placements = parseLayerMap(modelConfig.NumLayers, *layerMap)
+	} else if eng.Dispatcher() != nil {
+		// Use dispatcher's placements
+		placements = make([]tensor.DevicePlacement, modelConfig.NumLayers)
+		for i := 0; i < modelConfig.NumLayers; i++ {
+			placements[i] = eng.Dispatcher().LayerPlacement(i)
+		}
+	}
+
+	fmt.Println("Model loaded successfully!")
+
+	// Load Tokenizer
+	var tok *tokenizer.Tokenizer
+	tokData, err := eng.Loader().GetTokenizerData()
+	if err == nil && len(tokData) > 0 {
+		fmt.Println("Found embedded tokenizer in model file.")
+		tok, err = tokenizer.LoadFromBytes(tokData)
+		if err != nil {
+			log.Printf("Warning: failed to load embedded tokenizer: %v", err)
+		}
+	}
+
+	if tok == nil {
+		modelDir := filepath.Dir(*modelPath)
+		tokPath := filepath.Join(modelDir, "tokenizer.json")
+
+		fmt.Printf("Loading tokenizer from %s...\n", tokPath)
+		tok, err = tokenizer.LoadFromJSON(tokPath)
+		if err != nil {
+			log.Printf("Warning: failed to load tokenizer: %v", err)
+		}
+	}
+
+	// Format prompt (optionally with chat template)
+	finalPrompt := *prompt
+	if *useChat {
+		messages := []chat.Message{{Role: "user", Content: *prompt}}
+		formatted, err := chat.RenderForArchitecture(modelConfig.Architecture, messages, chat.Options{
+			AddGenerationPrompt: true,
+			EnableThinking:      true,
+		})
+		if err != nil {
+			log.Fatalf("format prompt failed: %v", err)
+		}
+		finalPrompt = formatted
+		fmt.Printf("Formatted prompt:\n%s\n", finalPrompt)
+	}
+
+	// Server mode: start HTTP server and block.
+	// We do this after model+tokenizer are loaded so all flags (GPU, KV cache sizes, etc.) apply.
+	if *listen != "" {
+		*serverMode = true
+	}
+	if *serverMode {
+		addr := *listen
+		if addr == "" {
+			addr = net.JoinHostPort(*host, strconv.Itoa(*port))
+		}
+		// Ensure JSON is linked in this binary (avoid unused import when tags change)
+		_ = json.Valid
+		err := openai.Serve(eng, tok, modelConfig.Architecture, openai.Config{
+			Listen:           addr,
+			MaxSeqLen:        *maxSeq,
+			BlockSize:        *blockSize,
+			KVCacheCPU:       *kvCacheCPU,
+			EnableThinking:   false,
+			PrefillChunkSize: *prefillChunkSize,
+			MaxConcurrent:    *maxConcurrent,
+		})
+		if err != nil {
+			log.Fatalf("server failed: %v", err)
+		}
+		return
+	}
+
+	// Tokenize prompt
+	var ids []int
+	if tok != nil {
+		ids = tok.Encode(finalPrompt)
+		fmt.Printf("Tokens: %v\n", ids)
+	} else {
+		ids = []int{1, 2, 3}
+	}
+
+	// Initialize KV Cache
+	var kv model.KVCache
+	var pagedCache *kvcache.PagedKVCache
+	if modelConfig.Architecture == "KimiLinearForCausalLM" {
+		params := modelConfig.Params
+		lacRaw := params["linear_attn_config"]
+		lac, _ := lacRaw.(map[string]any)
+		kdaNumHeads := int(lac["num_heads"].(float64))
+		kdaHeadDim := int(lac["head_dim"].(float64))
+		kdaKernel := int(lac["short_conv_kernel_size"].(float64))
+		mlaNumHeads := int(params["num_attention_heads"].(float64))
+		qkNope := int(params["qk_nope_head_dim"].(float64))
+		qkRope := int(params["qk_rope_head_dim"].(float64))
+		vDim := int(params["v_head_dim"].(float64))
+		kimiCache, err := kimi_linear.NewKimiCache(modelConfig.NumLayers, kdaNumHeads, kdaHeadDim, kdaKernel, mlaNumHeads, qkNope+qkRope, vDim)
+		if err != nil {
+			log.Fatalf("KimiCache alloc failed: %v", err)
+		}
+		kv = kimiCache
+		fmt.Println("KV cache: KimiCache (CPU)")
+	} else {
+		// Default: enable GPU KV per-layer when ANY layer is on GPU (mixed offload supported),
+		// unless --kv-cache-cpu is specified.
+		kvDevice := tensor.CPU
+		if !*kvCacheCPU && device.CUDAAvailable() {
+			for i := 0; i < modelConfig.NumLayers && i < len(placements); i++ {
+				if placements[i].Normalize().Type == tensor.CUDA {
+					kvDevice = tensor.CUDA
+					break
+				}
+			}
+		}
+		switch kvDevice {
+		case tensor.CUDA:
+			fmt.Println("KV cache: mixed (per-layer)")
+		default:
+			fmt.Println("KV cache: CPU")
+		}
+
+		pool, err := kvcache.NewBlockPool(kvcache.BlockPoolConfig{
+			NumLayers:  modelConfig.NumLayers,
+			NumKVHeads: modelConfig.NumKVHeads,
+			HeadDim:    modelConfig.HeadDim,
+			BlockSize:  *blockSize,
+			NumBlocks:  (*maxSeq + *blockSize - 1) / (*blockSize),
+			Device:     kvDevice,
+			GPU:        0,
+			LayerPlacements: func() []tensor.DevicePlacement {
+				if kvDevice != tensor.CUDA || len(placements) != modelConfig.NumLayers {
+					return nil
+				}
+				out := make([]tensor.DevicePlacement, modelConfig.NumLayers)
+				for i := 0; i < modelConfig.NumLayers; i++ {
+					out[i] = placements[i].Normalize()
+				}
+				return out
+			}(),
+			Preallocate: kvDevice == tensor.CUDA,
+		})
+		if err != nil {
+			log.Fatalf("NewBlockPool failed: %v", err)
+		}
+		pagedCache = kvcache.NewPagedKVCache(pool, kvcache.PagedCacheConfig{
+			NumLayers:  modelConfig.NumLayers,
+			NumKVHeads: modelConfig.NumKVHeads,
+			HeadDim:    modelConfig.HeadDim,
+			BlockSize:  *blockSize,
+			MaxSeqLen:  *maxSeq,
+			Device:     kvDevice,
+			GPU:        0,
+		}, "run-model")
+		if _, err := pagedCache.AllocateForTokens(ids); err != nil {
+			pagedCache.Free()
+			log.Fatalf("PagedKVCache alloc failed: %v", err)
+		}
+		defer pagedCache.Free()
+		kv = pagedCache
+	}
+
+	// Preallocate scratch buffers once so prefill doesn't hit cudaMalloc churn.
+	runCtx := context.Background()
+	needWarmup := false
+	if device.CUDAAvailable() && eng.Dispatcher() != nil && cuda.Available() {
+		gpuSeen := make(map[int]struct{})
+		var gpus []int
+		for i := 0; i < modelConfig.NumLayers; i++ {
+			p := eng.Dispatcher().LayerPlacement(i).Normalize()
+			if p.Type != tensor.CUDA || p.GPU < 0 {
+				continue
+			}
+			if _, ok := gpuSeen[p.GPU]; ok {
+				continue
+			}
+			gpuSeen[p.GPU] = struct{}{}
+			gpus = append(gpus, p.GPU)
+		}
+		if len(gpus) > 0 {
+			const minScratchBytes = 8 << 20
+			var (
+				ss          *compute.ScratchSet
+				scratchErr  error
+				scratchSize = compute.DefaultScratchBytes
+			)
+			for scratchSize >= minScratchBytes {
+				var err error
+				ss, err = compute.NewScratchSet(gpus, scratchSize)
+				if err == nil {
+					break
+				}
+				scratchErr = err
+				ss = nil
+				scratchSize /= 2
+			}
+			if ss != nil {
+				defer ss.Free()
+				runCtx = compute.WithScratchSet(runCtx, ss)
+				runCtx = compute.WithScratch(runCtx, ss.Scratch(gpus[0]))
+				needWarmup = true
+				log.Printf("scratch: gpus=%v bytes=%d", gpus, scratchSize)
+			} else if scratchErr != nil {
+				log.Printf("scratch disabled (alloc failed): %v", scratchErr)
+			}
+		}
+	}
+	if needWarmup {
+		if _, err := eng.Forward(runCtx, createInputTensor([]int{0}), createPositionTensor(0, 1), nil); err != nil {
+			log.Fatalf("warmup forward failed: %v", err)
+		}
+		compute.LogWeightCacheSummary()
+	}
+
+	// Initialize Sampler
+	sampler := sample.New(sample.Config{
+		Temperature:       float32(*temperature),
+		TopK:              *topK,
+		TopP:              float32(*topP),
+		RepetitionPenalty: float32(*repPenalty),
+		Seed:              -1,
+	})
+
+	// Prefill prompt in chunks (llama.cpp eval batch size analogue).
+	chunk := *prefillChunkSize
+	if chunk <= 0 {
+		chunk = 512
+	}
+	var logits tensor.Tensor
+	for start := 0; start < len(ids); start += chunk {
+		end := start + chunk
+		if end > len(ids) {
+			end = len(ids)
+		}
+		part := ids[start:end]
+		input := createInputTensor(part)
+		positions := createPositionTensor(kv.SeqLen(), len(part))
+
+		before := kv.SeqLen()
+		profile.Start("Prefill/Forward")
+		out, err := eng.Forward(runCtx, input, positions, kv)
+		profile.End("Prefill/Forward")
+		if err != nil {
+			log.Fatalf("Prefill forward failed: %v", err)
+		}
+		logits = out
+		if kv.SeqLen() == before {
+			kv.Commit(len(part))
+		}
+	}
+	if logits == nil {
+		log.Fatalf("prefill produced nil logits")
+	}
+
+	// Sample first generated token
+	lastPartLen := len(ids) % chunk
+	if lastPartLen == 0 {
+		if chunk < len(ids) {
+			lastPartLen = chunk
+		} else {
+			lastPartLen = len(ids)
+		}
+	}
+	rowIdx := lastPartLen - 1
+	logitsSlice := getLogitsRowCPU(logits, rowIdx)
+	var nextToken int
+	if logitsSlice != nil {
+		profile.Start("Prefill/Sample")
+		nextToken = sampler.Sample(logitsSlice, ids)
+		profile.End("Prefill/Sample")
+	} else {
+		// CUDA path: take top-k (or argmax) from GPU, copy only small candidate list.
+		gpuLogits := logits.(*cuda.Tensor)
+		vocabSize := gpuLogits.Shape()[1]
+		view, err := gpuLogits.ViewAt(tensor.Shape{vocabSize}, uintptr(rowIdx*vocabSize*4))
+		if err != nil {
+			log.Fatalf("logits view failed: %v", err)
+		}
+		k := *topK
+		if *temperature == 0 {
+			k = 1
+		}
+		if k <= 0 {
+			// Semantics-preserving fallback: copy full logits row to CPU and use existing sampler.
+			host := make([]float32, vocabSize)
+			profile.Start("Prefill/LogitsD2H")
+			if err := view.CopyToHost(host); err != nil {
+				log.Fatalf("logits D2H failed: %v", err)
+			}
+			profile.End("Prefill/LogitsD2H")
+			profile.Start("Prefill/Sample")
+			nextToken = sampler.Sample(host, ids)
+			profile.End("Prefill/Sample")
+			goto sampledPrefill
+		}
+		recent := ids
+		if len(recent) > 64 {
+			recent = recent[len(recent)-64:]
+		}
+		repIDs := make([]int32, len(recent))
+		for i, t := range recent {
+			repIDs[i] = int32(t)
+		}
+		profile.Start("Prefill/TopK")
+		allIDs, allScores, blocks, err := cuda.TopKLogitsF32(view.Data().(unsafe.Pointer), vocabSize, repIDs, float32(*repPenalty), k, gpuLogits.GPU())
+		profile.End("Prefill/TopK")
+		if err != nil {
+			log.Fatalf("cuda topk failed: %v", err)
+		}
+		// Merge per-block candidates on CPU to get global top-k
+		cands := make([]struct {
+			id    int32
+			score float32
+		}, 0, blocks*k)
+		for i := 0; i < blocks*k; i++ {
+			if allIDs[i] < 0 {
+				continue
+			}
+			cands = append(cands, struct {
+				id    int32
+				score float32
+			}{id: allIDs[i], score: allScores[i]})
+		}
+		sort.Slice(cands, func(i, j int) bool { return cands[i].score > cands[j].score })
+		if len(cands) > k {
+			cands = cands[:k]
+		}
+		finalIDs := make([]int32, len(cands))
+		finalScores := make([]float32, len(cands))
+		for i := range cands {
+			finalIDs[i] = cands[i].id
+			finalScores[i] = cands[i].score
+		}
+		profile.Start("Prefill/Sample")
+		nextToken = sampler.SampleFromTopK(finalIDs, finalScores)
+		profile.End("Prefill/Sample")
+	}
+
+sampledPrefill:
+	if tok != nil {
+		fmt.Print(tok.Decode([]int{nextToken}))
+	}
+	ids = append(ids, nextToken)
+	if pagedCache != nil {
+		pagedCache.AppendToken(nextToken)
+	}
+
+	// Autoregressive generation with KV Cache
+	eosID := 151645 // <|im_end|>
+	if tok != nil {
+		eosID = tok.EosID()
+	}
+
+	startGen := time.Now()
+	genTokens := 0
+
+	for i := 1; i < *steps; i++ {
+		profile.TokenStart()
+		// Check for EOS
+		if nextToken == eosID {
+			profile.TokenEnd()
+			break
+		}
+
+		// Prepare single token input
+		input := createInputTensor([]int{nextToken})
+		currentPos := len(ids) - 1
+		positions := createPositionTensor(currentPos, 1)
+
+		profile.Start("Decode/Forward")
+		logits, err = eng.Forward(runCtx, input, positions, kv)
+		profile.End("Decode/Forward")
+		if err != nil {
+			log.Fatalf("Forward failed: %v", err)
+		}
+
+		// Sample with recent context for repetition penalty
+		logitsSlice = getLogitsRowCPU(logits, 0)
+		recentTokens := ids
+		if len(ids) > 64 {
+			recentTokens = ids[len(ids)-64:]
+		}
+		if logitsSlice != nil {
+			profile.Start("Decode/Sample")
+			nextToken = sampler.Sample(logitsSlice, recentTokens)
+			profile.End("Decode/Sample")
+		} else {
+			gpuLogits := logits.(*cuda.Tensor)
+			vocabSize := gpuLogits.Shape()[1]
+			view, err := gpuLogits.ViewAt(tensor.Shape{vocabSize}, 0)
+			if err != nil {
+				log.Fatalf("logits view failed: %v", err)
+			}
+			k := *topK
+			if *temperature == 0 {
+				k = 1
+			}
+			if k <= 0 {
+				host := make([]float32, vocabSize)
+				profile.Start("Decode/LogitsD2H")
+				if err := view.CopyToHost(host); err != nil {
+					log.Fatalf("logits D2H failed: %v", err)
+				}
+				profile.End("Decode/LogitsD2H")
+				profile.Start("Decode/Sample")
+				nextToken = sampler.Sample(host, recentTokens)
+				profile.End("Decode/Sample")
+				goto sampledDecode
+			}
+			repIDs := make([]int32, len(recentTokens))
+			for i, t := range recentTokens {
+				repIDs[i] = int32(t)
+			}
+			profile.Start("Decode/TopK")
+			allIDs, allScores, blocks, err := cuda.TopKLogitsF32(view.Data().(unsafe.Pointer), vocabSize, repIDs, float32(*repPenalty), k, gpuLogits.GPU())
+			profile.End("Decode/TopK")
+			if err != nil {
+				log.Fatalf("cuda topk failed: %v", err)
+			}
+			cands := make([]struct {
+				id    int32
+				score float32
+			}, 0, blocks*k)
+			for i := 0; i < blocks*k; i++ {
+				if allIDs[i] < 0 {
+					continue
+				}
+				cands = append(cands, struct {
+					id    int32
+					score float32
+				}{id: allIDs[i], score: allScores[i]})
+			}
+			sort.Slice(cands, func(i, j int) bool { return cands[i].score > cands[j].score })
+			if len(cands) > k {
+				cands = cands[:k]
+			}
+			finalIDs := make([]int32, len(cands))
+			finalScores := make([]float32, len(cands))
+			for i := range cands {
+				finalIDs[i] = cands[i].id
+				finalScores[i] = cands[i].score
+			}
+			profile.Start("Decode/Sample")
+			nextToken = sampler.SampleFromTopK(finalIDs, finalScores)
+			profile.End("Decode/Sample")
+		}
+
+	sampledDecode:
+
+		if tok != nil {
+			fmt.Print(tok.Decode([]int{nextToken}))
+		}
+		ids = append(ids, nextToken)
+		if pagedCache != nil {
+			pagedCache.AppendToken(nextToken)
+		}
+		genTokens++
+		profile.TokenEnd()
+	}
+	duration := time.Since(startGen)
+	fmt.Printf("\n\nDone. Generated %d tokens in %v (%.2f tok/s)\n", genTokens, duration, float64(genTokens)/duration.Seconds())
+}
+
+func createInputTensor(ids []int) tensor.Tensor {
+	t := cpu.NewTensor(tensor.Shape{len(ids)}, nil)
+	data := t.DataFloat32()
+	for i, id := range ids {
+		data[i] = float32(id)
+	}
+	return t
+}
+
+func createPositionTensor(start, count int) tensor.Tensor {
+	t := cpu.NewTensor(tensor.Shape{count}, nil)
+	data := t.DataFloat32()
+	for i := 0; i < count; i++ {
+		data[i] = float32(start + i)
+	}
+	return t
+}
+
+func getLogitsRowCPU(logits tensor.Tensor, row int) []float32 {
+	if _, ok := logits.(*cpu.Tensor); !ok {
+		return nil
+	}
+	data := logits.Data().(unsafe.Pointer)
+	shape := logits.Shape()
+	vocabSize := shape[1]
+	slice := unsafe.Slice((*float32)(data), shape.NumElements())
+	return slice[row*vocabSize : (row+1)*vocabSize]
+}
+
+// parseLayerMap parses a comma-separated placement string like
+// "0-9:gpu0,10-19:gpu1,20-:cpu" and returns per-layer placements.
+func parseLayerMap(numLayers int, spec string) []tensor.DevicePlacement {
+	placements := make([]tensor.DevicePlacement, numLayers)
+	for i := range placements {
+		placements[i] = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
+	}
+	if spec == "" {
+		return placements
+	}
+
+	entries := strings.Split(spec, ",")
+	for _, entry := range entries {
+		entry = strings.TrimSpace(entry)
+		if entry == "" {
+			continue
+		}
+		parts := strings.Split(entry, ":")
+		if len(parts) != 2 {
+			log.Printf("invalid layer-map entry %q, skipping", entry)
+			continue
+		}
+		rng, target := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
+		start, end := 0, numLayers-1
+		if rng != "" {
+			if strings.Contains(rng, "-") {
+				rp := strings.SplitN(rng, "-", 2)
+				if rp[0] != "" {
+					if v, err := strconv.Atoi(rp[0]); err == nil {
+						start = v
+					}
+				}
+				if rp[1] != "" {
+					if v, err := strconv.Atoi(rp[1]); err == nil {
+						end = v
+					}
+				}
+			} else if v, err := strconv.Atoi(rng); err == nil {
+				start, end = v, v
+			}
+		}
+		if start < 0 {
+			start = 0
+		}
+		if end >= numLayers {
+			end = numLayers - 1
+		}
+		var placement tensor.DevicePlacement
+		switch {
+		case strings.HasPrefix(strings.ToLower(target), "gpu"):
+			idStr := strings.TrimPrefix(strings.ToLower(target), "gpu")
+			id := 0
+			if idStr != "" {
+				if v, err := strconv.Atoi(idStr); err == nil {
+					id = v
+				}
+			}
+			placement = tensor.DevicePlacement{Type: tensor.CUDA, GPU: id}.Normalize()
+		case strings.ToLower(target) == "cpu":
+			placement = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
+		default:
+			log.Printf("unknown target %q, defaulting to CPU", target)
+			placement = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
+		}
+		for i := start; i <= end && i < numLayers; i++ {
+			placements[i] = placement
+		}
+	}
+	return placements
+}

+ 26 - 0
cmd/run-model/main_test.go

@@ -0,0 +1,26 @@
+package main
+
+import (
+	"testing"
+
+	"makarna/pkg/tensor"
+)
+
+func TestParseLayerMap(t *testing.T) {
+	placements := parseLayerMap(5, "0-1:gpu0,2-3:gpu1,4:cpu")
+	if len(placements) != 5 {
+		t.Fatalf("expected 5 placements, got %d", len(placements))
+	}
+	expectPlacement(t, placements[0], tensor.CUDA, 0)
+	expectPlacement(t, placements[1], tensor.CUDA, 0)
+	expectPlacement(t, placements[2], tensor.CUDA, 1)
+	expectPlacement(t, placements[3], tensor.CUDA, 1)
+	expectPlacement(t, placements[4], tensor.CPU, -1)
+}
+
+func expectPlacement(t *testing.T, p tensor.DevicePlacement, device tensor.DeviceType, gpu int) {
+	if p.Type != device || p.GPU != gpu {
+		t.Fatalf("expected device %v gpu %d, got %v gpu %d", device, gpu, p.Type, p.GPU)
+	}
+}
+

+ 8 - 0
go.mod

@@ -0,0 +1,8 @@
+module makarna
+
+go 1.22
+
+require (
+	github.com/nlpodyssey/safetensors v0.0.0-20250209183917-bfb01cc25f7c
+	golang.org/x/sys v0.26.0
+)

+ 12 - 0
go.sum

@@ -0,0 +1,12 @@
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/nlpodyssey/safetensors v0.0.0-20250209183917-bfb01cc25f7c h1:n1I/GOYbUEUA+GlHCQRDtX/TO8WjcaOye7tzNZ0reV4=
+github.com/nlpodyssey/safetensors v0.0.0-20250209183917-bfb01cc25f7c/go.mod h1:127qdlpPthu+2XqN6+n4evoaXw/w0qluNuDHr1tpBT4=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
+github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
+golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

+ 145 - 0
pkg/backend/cpu/affinity_linux.go

@@ -0,0 +1,145 @@
+//go:build linux
+
+package cpu
+
+import (
+	"os"
+	"runtime"
+	"sort"
+	"strconv"
+	"strings"
+
+	"golang.org/x/sys/unix"
+)
+
+var (
+	affinityEnabled bool
+	affinitySet     unix.CPUSet
+)
+
+func init() {
+	// Optional: enable via env to keep default behavior unchanged.
+	//
+	// - MAKARNA_CPUSET: e.g. "0-15,32-47"
+	// - MAKARNA_NUMA_NODE: e.g. "0" (reads /sys/devices/system/node/node0/cpulist)
+	if err := initAffinityFromEnv(); err != nil {
+		affinityEnabled = false
+	}
+}
+
+func initAffinityFromEnv() error {
+	if cpus, ok := cpusFromEnv("MAKARNA_CPUSET"); ok {
+		return setAffinityCPUs(cpus)
+	}
+	if nodeStr := strings.TrimSpace(os.Getenv("MAKARNA_NUMA_NODE")); nodeStr != "" {
+		node, err := strconv.Atoi(nodeStr)
+		if err != nil || node < 0 {
+			return err
+		}
+		data, err := os.ReadFile("/sys/devices/system/node/node" + strconv.Itoa(node) + "/cpulist")
+		if err != nil {
+			return err
+		}
+		cpus, err := parseCPUList(strings.TrimSpace(string(data)))
+		if err != nil {
+			return err
+		}
+		return setAffinityCPUs(cpus)
+	}
+	return nil
+}
+
+func cpusFromEnv(key string) ([]int, bool) {
+	raw := strings.TrimSpace(os.Getenv(key))
+	if raw == "" {
+		return nil, false
+	}
+	cpus, err := parseCPUList(raw)
+	if err != nil {
+		return nil, false
+	}
+	return cpus, true
+}
+
+func parseCPUList(s string) ([]int, error) {
+	if s == "" {
+		return nil, strconv.ErrSyntax
+	}
+	var cpus []int
+	for _, part := range strings.Split(s, ",") {
+		part = strings.TrimSpace(part)
+		if part == "" {
+			continue
+		}
+		if lo, hi, ok := strings.Cut(part, "-"); ok {
+			start, err := strconv.Atoi(strings.TrimSpace(lo))
+			if err != nil {
+				return nil, err
+			}
+			end, err := strconv.Atoi(strings.TrimSpace(hi))
+			if err != nil {
+				return nil, err
+			}
+			if start < 0 || end < start {
+				return nil, strconv.ErrSyntax
+			}
+			for c := start; c <= end; c++ {
+				cpus = append(cpus, c)
+			}
+			continue
+		}
+		v, err := strconv.Atoi(part)
+		if err != nil {
+			return nil, err
+		}
+		if v < 0 {
+			return nil, strconv.ErrSyntax
+		}
+		cpus = append(cpus, v)
+	}
+	if len(cpus) == 0 {
+		return nil, strconv.ErrSyntax
+	}
+	sort.Ints(cpus)
+	cpus = compactInts(cpus)
+	return cpus, nil
+}
+
+func compactInts(xs []int) []int {
+	if len(xs) < 2 {
+		return xs
+	}
+	out := xs[:1]
+	for _, v := range xs[1:] {
+		if v == out[len(out)-1] {
+			continue
+		}
+		out = append(out, v)
+	}
+	return out
+}
+
+func setAffinityCPUs(cpus []int) error {
+	var set unix.CPUSet
+	set.Zero()
+	for _, c := range cpus {
+		set.Set(c)
+	}
+	affinitySet = set
+	affinityEnabled = true
+	return nil
+}
+
+// WithPinnedThread runs fn on a locked OS thread with the optional CPU affinity
+// applied (enabled via env). If affinity is not configured, it simply calls fn().
+func WithPinnedThread(fn func()) {
+	if !affinityEnabled {
+		fn()
+		return
+	}
+	runtime.LockOSThread()
+	defer runtime.UnlockOSThread()
+	_ = unix.SchedSetaffinity(0, &affinitySet)
+	fn()
+}
+
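A rough usage sketch (module path makarna taken from go.mod): affinity stays opt-in via MAKARNA_CPUSET or MAKARNA_NUMA_NODE in the environment, and hot loops wrap themselves in WithPinnedThread.

// Hypothetical caller; run as e.g. MAKARNA_CPUSET=0-7 ./your-binary on Linux.
package main

import "makarna/pkg/backend/cpu"

func main() {
	cpu.WithPinnedThread(func() {
		// Compute-heavy work here runs on a locked OS thread restricted to the
		// configured CPU set; without the env vars this is just a plain call.
	})
}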

+ 7 - 0
pkg/backend/cpu/affinity_other.go

@@ -0,0 +1,7 @@
+//go:build !linux
+
+package cpu
+
+// WithPinnedThread is a no-op shim on non-Linux platforms.
+func WithPinnedThread(fn func()) { fn() }
+

+ 92 - 0
pkg/backend/cpu/cpufeat.go

@@ -0,0 +1,92 @@
+package cpu
+
+import "golang.org/x/sys/cpu"
+
+// Features captures the SIMD capabilities we care about for dispatching
+// CPU kernels. Values are populated once at init using the runtime cpuid
+// information from x/sys/cpu.
+type Features struct {
+	FMA      bool
+	AVX2     bool
+	AVX512F  bool
+	AVX512DQ bool
+	AVX512BW bool
+	AVX512VL bool
+
+	AVX512VNNI bool
+
+	AMXTile bool
+	AMXInt8 bool
+	AMXBF16 bool
+}
+
+var detected Features
+
+func init() {
+	detected = Features{
+		FMA:      cpu.X86.HasFMA,
+		AVX2:     cpu.X86.HasAVX2,
+		AVX512F:  cpu.X86.HasAVX512F,
+		AVX512DQ: cpu.X86.HasAVX512DQ,
+		AVX512BW: cpu.X86.HasAVX512BW,
+		AVX512VL: cpu.X86.HasAVX512VL,
+
+		AVX512VNNI: cpu.X86.HasAVX512VNNI,
+
+		AMXTile: cpu.X86.HasAMXTile,
+		AMXInt8: cpu.X86.HasAMXInt8,
+		AMXBF16: cpu.X86.HasAMXBF16,
+	}
+}
+
+// CPUFeatures returns the detected SIMD feature set.
+func CPUFeatures() Features {
+	return detected
+}
+
+// SupportsAVX512 reports whether the CPU supports the AVX-512 subset we use.
+// We require FMA plus F, DQ, BW and VL to match ggml-style kernels.
+func SupportsAVX512() bool {
+	return detected.FMA && detected.AVX512F && detected.AVX512DQ && detected.AVX512BW && detected.AVX512VL
+}
+
+// SupportsAVX2 reports whether AVX2 (together with FMA) is available.
+func SupportsAVX2() bool {
+	return detected.FMA && detected.AVX2
+}
+
+// SupportsAVX512VNNI reports whether the CPU supports AVX-512 VNNI instructions.
+func SupportsAVX512VNNI() bool {
+	return SupportsAVX512() && detected.AVX512VNNI
+}
+
+// SupportsAMXInt8 reports whether the CPU supports AMX INT8 tile ops.
+func SupportsAMXInt8() bool {
+	return detected.AMXTile && detected.AMXInt8
+}
+
+// SupportsAMXBF16 reports whether the CPU supports AMX BF16 tile ops.
+func SupportsAMXBF16() bool {
+	return detected.AMXTile && detected.AMXBF16
+}
+
+// SIMDLevel represents the best available SIMD tier.
+type SIMDLevel int
+
+const (
+	SIMDNone SIMDLevel = iota
+	SIMDAVX2
+	SIMDAVX512
+)
+
+// BestSIMD returns the highest SIMD level supported (AVX-512 over AVX2).
+func BestSIMD() SIMDLevel {
+	switch {
+	case SupportsAVX512():
+		return SIMDAVX512
+	case SupportsAVX2():
+		return SIMDAVX2
+	default:
+		return SIMDNone
+	}
+}
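A minimal dispatch sketch built on the helpers above (illustrative only; the kernel names in the strings are placeholders, not functions in this commit):

package main

import (
	"fmt"

	"makarna/pkg/backend/cpu"
)

func main() {
	switch cpu.BestSIMD() {
	case cpu.SIMDAVX512:
		fmt.Println("dispatch: AVX-512 kernels")
	case cpu.SIMDAVX2:
		fmt.Println("dispatch: AVX2 kernels")
	default:
		fmt.Println("dispatch: scalar fallback")
	}
}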

+ 130 - 0
pkg/backend/cpu/matmul/gemm_blocked.go

@@ -0,0 +1,130 @@
+package matmul
+
+import (
+	"sync"
+
+	"makarna/pkg/backend/cpu"
+)
+
+const (
+	minWorkPerWorker = 8192 // heuristic threshold on total M*N*K work before fanning out
+)
+
+// chooseWorkers returns a bounded worker count based on total work units and the
+// maximum allowed.
+func chooseWorkers(total, max int) int {
+	if max < 1 {
+		return 1
+	}
+	if total <= minWorkPerWorker {
+		return 1
+	}
+	need := (total + minWorkPerWorker - 1) / minWorkPerWorker
+	if need < 1 {
+		need = 1
+	}
+	if need > max {
+		need = max
+	}
+	return need
+}
+
+// chunkRanges splits [0,total) into at most parts chunks.
+func chunkRanges(total, parts int) [][2]int {
+	if parts < 1 {
+		parts = 1
+	}
+	chunk := (total + parts - 1) / parts
+	if chunk < 1 {
+		chunk = 1
+	}
+	var ranges [][2]int
+	for start := 0; start < total; start += chunk {
+		end := start + chunk
+		if end > total {
+			end = total
+		}
+		ranges = append(ranges, [2]int{start, end})
+	}
+	return ranges
+}
+
+// gemmFloat32Blocked computes C = A x B^T (all row-major) where
+// A: MxK, B: NxK (row-major weights), C: MxN.
+// It uses a register-blocked 1x8 micro-kernel (when available) and
+// parallelizes across rows or columns depending on shape, without packing.
+func gemmFloat32Blocked(out, a, b []float32, M, K, N, maxWorkers int) {
+	// Use an approximate MAC count to decide parallelism. Using only M*N can
+	// underutilize CPU cores in decode (M==1) where K is large.
+	total := M * N * K
+	workers := chooseWorkers(total, maxWorkers)
+	if workers == 1 {
+		gemmFloat32Scalar(out, a, b, M, K, N)
+		return
+	}
+
+	// Decode-path specialization: M == 1, split across N
+	if M == 1 {
+		ranges := chunkRanges(N, workers)
+		var wg sync.WaitGroup
+		for _, r := range ranges {
+			wg.Add(1)
+			start, end := r[0], r[1]
+			go func(s, e int) {
+				defer wg.Done()
+				cpu.WithPinnedThread(func() {
+					gemvFloat32Range(out, a[:K], b, K, s, e)
+				})
+			}(start, end)
+		}
+		wg.Wait()
+		return
+	}
+
+	if M < workers && N > 1 {
+		ranges := chunkRanges(N, workers)
+		var wg sync.WaitGroup
+		for _, r := range ranges {
+			wg.Add(1)
+			start, end := r[0], r[1]
+			go func(s, e int) {
+				defer wg.Done()
+				cpu.WithPinnedThread(func() {
+					for m := 0; m < M; m++ {
+						row := a[m*K : (m+1)*K]
+						base := out[m*N : (m+1)*N]
+						gemvFloat32Range(base, row, b, K, s, e)
+					}
+				})
+			}(start, end)
+		}
+		wg.Wait()
+		return
+	}
+
+	ranges := chunkRanges(M, workers)
+	var wg sync.WaitGroup
+	for _, r := range ranges {
+		wg.Add(1)
+		start, end := r[0], r[1]
+		go func(s, e int) {
+			defer wg.Done()
+			cpu.WithPinnedThread(func() {
+				for m := s; m < e; m++ {
+					row := a[m*K : (m+1)*K]
+					base := out[m*N : (m+1)*N]
+					gemvFloat32Range(base, row, b, K, 0, N)
+				}
+			})
+		}(start, end)
+	}
+	wg.Wait()
+}
+
+func gemmFloat32Scalar(out, a, b []float32, M, K, N int) {
+	for m := 0; m < M; m++ {
+		row := a[m*K : (m+1)*K]
+		base := out[m*N : (m+1)*N]
+		gemvFloat32Range(base, row, b, K, 0, N)
+	}
+}
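For intuition on the splitting heuristics (numbers illustrative): a decode-shaped call with M=1, K=4096, N=4096 has total = M*N*K = 16,777,216 MACs, so chooseWorkers asks for ceil(16,777,216/8192) = 2048 workers, gets capped at maxWorkers, and chunkRanges then splits the 4096 output columns evenly.

// Sketch in the same package, assuming 16 allowed workers.
func exampleSplit() {
	workers := chooseWorkers(1*4096*4096, 16) // 2048 needed, capped to 16
	ranges := chunkRanges(4096, workers)      // 16 ranges of 256 columns each
	_, _ = workers, ranges
}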

+ 202 - 0
pkg/backend/cpu/matmul/gemv_f32_tile8_avx2.s

@@ -0,0 +1,202 @@
+//go:build amd64
+// +build amd64
+
+#include "textflag.h"
+
+// func gemvF32Tile8AVX2(a *float32, b *float32, out *float32, K int)
+// Computes 8 independent dot products:
+//   out[t] = sum_{i=0..K-1} a[i] * b[t*K+i], for t=0..7
+// Vectorizes over K with AVX2/FMA and reuses each A vector across 8 outputs.
+TEXT ·gemvF32Tile8AVX2(SB), NOSPLIT, $96-32
+	// Preserve general-purpose registers (Go ABI + ABI wrappers).
+	MOVQ AX, 0(SP)
+	MOVQ BX, 8(SP)
+	MOVQ CX, 16(SP)
+	MOVQ DX, 24(SP)
+	MOVQ DI, 32(SP)
+	MOVQ SI, 40(SP)
+	MOVQ R8, 48(SP)
+	MOVQ R9, 56(SP)
+	MOVQ R10, 64(SP)
+	MOVQ R11, 72(SP)
+	MOVQ R12, 80(SP)
+	MOVQ R13, 88(SP)
+
+	MOVQ a+0(FP), DI
+	MOVQ b+8(FP), SI
+	MOVQ out+16(FP), DX
+	MOVQ K+24(FP), CX
+
+	// strideBytes = K * 4
+	MOVQ CX, BX
+	SHLQ $2, BX
+
+	// kMain = K &^ 7 (multiple of 8 floats)
+	ANDQ $-8, CX
+	JLE zero
+
+	// b1..b7 pointers
+	MOVQ SI, R8
+	ADDQ BX, R8
+	MOVQ R8, R9
+	ADDQ BX, R9
+	MOVQ R9, R10
+	ADDQ BX, R10
+	MOVQ R10, R11
+	ADDQ BX, R11
+	MOVQ R11, R12
+	ADDQ BX, R12
+	MOVQ R12, R13
+	ADDQ BX, R13
+	MOVQ R13, AX
+	ADDQ BX, AX
+
+	// zero accumulators Y0..Y7
+	VXORPS Y0, Y0, Y0
+	VXORPS Y1, Y1, Y1
+	VXORPS Y2, Y2, Y2
+	VXORPS Y3, Y3, Y3
+	VXORPS Y4, Y4, Y4
+	VXORPS Y5, Y5, Y5
+	VXORPS Y6, Y6, Y6
+	VXORPS Y7, Y7, Y7
+
+loop:
+	// load 8 floats from a
+	VMOVUPS (DI), Y8
+
+	// out0..out7 accumulate with shared A vector
+	VMOVUPS (SI), Y9
+	VFMADD231PS Y8, Y9, Y0
+	VMOVUPS (R8), Y9
+	VFMADD231PS Y8, Y9, Y1
+	VMOVUPS (R9), Y9
+	VFMADD231PS Y8, Y9, Y2
+	VMOVUPS (R10), Y9
+	VFMADD231PS Y8, Y9, Y3
+	VMOVUPS (R11), Y9
+	VFMADD231PS Y8, Y9, Y4
+	VMOVUPS (R12), Y9
+	VFMADD231PS Y8, Y9, Y5
+	VMOVUPS (R13), Y9
+	VFMADD231PS Y8, Y9, Y6
+	VMOVUPS (AX), Y9
+	VFMADD231PS Y8, Y9, Y7
+
+	// advance pointers
+	ADDQ $32, DI
+	ADDQ $32, SI
+	ADDQ $32, R8
+	ADDQ $32, R9
+	ADDQ $32, R10
+	ADDQ $32, R11
+	ADDQ $32, R12
+	ADDQ $32, R13
+	ADDQ $32, AX
+
+	SUBQ $8, CX
+	JNZ loop
+
+	// Reduce each accumulator to scalar and store.
+	// Y0 -> out[0]
+	VEXTRACTF128 $1, Y0, X8
+	VADDPS X8, X0, X0
+	VMOVHLPS X0, X0, X8
+	VADDPS X8, X0, X0
+	VPSHUFD $0xB1, X0, X8
+	VADDPS X8, X0, X0
+	MOVSS X0, 0(DX)
+
+	// Y1 -> out[1]
+	VEXTRACTF128 $1, Y1, X8
+	VADDPS X8, X1, X1
+	VMOVHLPS X1, X1, X8
+	VADDPS X8, X1, X1
+	VPSHUFD $0xB1, X1, X8
+	VADDPS X8, X1, X1
+	MOVSS X1, 4(DX)
+
+	// Y2 -> out[2]
+	VEXTRACTF128 $1, Y2, X8
+	VADDPS X8, X2, X2
+	VMOVHLPS X2, X2, X8
+	VADDPS X8, X2, X2
+	VPSHUFD $0xB1, X2, X8
+	VADDPS X8, X2, X2
+	MOVSS X2, 8(DX)
+
+	// Y3 -> out[3]
+	VEXTRACTF128 $1, Y3, X8
+	VADDPS X8, X3, X3
+	VMOVHLPS X3, X3, X8
+	VADDPS X8, X3, X3
+	VPSHUFD $0xB1, X3, X8
+	VADDPS X8, X3, X3
+	MOVSS X3, 12(DX)
+
+	// Y4 -> out[4]
+	VEXTRACTF128 $1, Y4, X8
+	VADDPS X8, X4, X4
+	VMOVHLPS X4, X4, X8
+	VADDPS X8, X4, X4
+	VPSHUFD $0xB1, X4, X8
+	VADDPS X8, X4, X4
+	MOVSS X4, 16(DX)
+
+	// Y5 -> out[5]
+	VEXTRACTF128 $1, Y5, X8
+	VADDPS X8, X5, X5
+	VMOVHLPS X5, X5, X8
+	VADDPS X8, X5, X5
+	VPSHUFD $0xB1, X5, X8
+	VADDPS X8, X5, X5
+	MOVSS X5, 20(DX)
+
+	// Y6 -> out[6]
+	VEXTRACTF128 $1, Y6, X8
+	VADDPS X8, X6, X6
+	VMOVHLPS X6, X6, X8
+	VADDPS X8, X6, X6
+	VPSHUFD $0xB1, X6, X8
+	VADDPS X8, X6, X6
+	MOVSS X6, 24(DX)
+
+	// Y7 -> out[7]
+	VEXTRACTF128 $1, Y7, X8
+	VADDPS X8, X7, X7
+	VMOVHLPS X7, X7, X8
+	VADDPS X8, X7, X7
+	VPSHUFD $0xB1, X7, X8
+	VADDPS X8, X7, X7
+	MOVSS X7, 28(DX)
+
+	VZEROUPPER
+	JMP epilogue
+
+zero:
+	VXORPS X0, X0, X0
+	MOVSS X0, 0(DX)
+	MOVSS X0, 4(DX)
+	MOVSS X0, 8(DX)
+	MOVSS X0, 12(DX)
+	MOVSS X0, 16(DX)
+	MOVSS X0, 20(DX)
+	MOVSS X0, 24(DX)
+	MOVSS X0, 28(DX)
+	VZEROUPPER
+	JMP epilogue
+
+epilogue:
+	MOVQ 0(SP), AX
+	MOVQ 8(SP), BX
+	MOVQ 16(SP), CX
+	MOVQ 24(SP), DX
+	MOVQ 32(SP), DI
+	MOVQ 40(SP), SI
+	MOVQ 48(SP), R8
+	MOVQ 56(SP), R9
+	MOVQ 64(SP), R10
+	MOVQ 72(SP), R11
+	MOVQ 80(SP), R12
+	MOVQ 88(SP), R13
+	RET

+ 215 - 0
pkg/backend/cpu/matmul/gemv_f32_tile8_avx512.s

@@ -0,0 +1,215 @@
+//go:build amd64
+// +build amd64
+
+#include "textflag.h"
+
+// func gemvF32Tile8AVX512(a *float32, b *float32, out *float32, K int)
+// Computes 8 independent dot products:
+//   out[t] = sum_{i=0..K-1} a[i] * b[t*K+i], for t=0..7
+// Vectorizes over K with AVX-512/FMA and reuses each A vector across 8 outputs.
+TEXT ·gemvF32Tile8AVX512(SB), NOSPLIT, $96-32
+	// Preserve general-purpose registers (Go ABI + ABI wrappers).
+	MOVQ AX, 0(SP)
+	MOVQ BX, 8(SP)
+	MOVQ CX, 16(SP)
+	MOVQ DX, 24(SP)
+	MOVQ DI, 32(SP)
+	MOVQ SI, 40(SP)
+	MOVQ R8, 48(SP)
+	MOVQ R9, 56(SP)
+	MOVQ R10, 64(SP)
+	MOVQ R11, 72(SP)
+	MOVQ R12, 80(SP)
+	MOVQ R13, 88(SP)
+
+	MOVQ a+0(FP), DI
+	MOVQ b+8(FP), SI
+	MOVQ out+16(FP), DX
+	MOVQ K+24(FP), CX
+
+	// strideBytes = K * 4
+	MOVQ CX, BX
+	SHLQ $2, BX
+
+	// kMain = K &^ 15 (multiple of 16 floats)
+	ANDQ $-16, CX
+	JLE zero
+
+	// b1..b7 pointers
+	MOVQ SI, R8
+	ADDQ BX, R8
+	MOVQ R8, R9
+	ADDQ BX, R9
+	MOVQ R9, R10
+	ADDQ BX, R10
+	MOVQ R10, R11
+	ADDQ BX, R11
+	MOVQ R11, R12
+	ADDQ BX, R12
+	MOVQ R12, R13
+	ADDQ BX, R13
+	MOVQ R13, AX
+	ADDQ BX, AX
+
+	// zero accumulators Z0..Z7
+	VXORPS Z0, Z0, Z0
+	VXORPS Z1, Z1, Z1
+	VXORPS Z2, Z2, Z2
+	VXORPS Z3, Z3, Z3
+	VXORPS Z4, Z4, Z4
+	VXORPS Z5, Z5, Z5
+	VXORPS Z6, Z6, Z6
+	VXORPS Z7, Z7, Z7
+
+loop:
+	VMOVUPS (DI), Z8
+
+	VMOVUPS (SI), Z9
+	VFMADD231PS Z8, Z9, Z0
+	VMOVUPS (R8), Z9
+	VFMADD231PS Z8, Z9, Z1
+	VMOVUPS (R9), Z9
+	VFMADD231PS Z8, Z9, Z2
+	VMOVUPS (R10), Z9
+	VFMADD231PS Z8, Z9, Z3
+	VMOVUPS (R11), Z9
+	VFMADD231PS Z8, Z9, Z4
+	VMOVUPS (R12), Z9
+	VFMADD231PS Z8, Z9, Z5
+	VMOVUPS (R13), Z9
+	VFMADD231PS Z8, Z9, Z6
+	VMOVUPS (AX), Z9
+	VFMADD231PS Z8, Z9, Z7
+
+	ADDQ $64, DI
+	ADDQ $64, SI
+	ADDQ $64, R8
+	ADDQ $64, R9
+	ADDQ $64, R10
+	ADDQ $64, R11
+	ADDQ $64, R12
+	ADDQ $64, R13
+	ADDQ $64, AX
+
+	SUBQ $16, CX
+	JNZ loop
+
+	// Reduce each accumulator to scalar and store.
+	// Z0 -> out[0]
+	VEXTRACTF32X8 $1, Z0, Y8
+	VADDPS Y8, Y0, Y0
+	VEXTRACTF128 $1, Y0, X8
+	VADDPS X8, X0, X0
+	VPSHUFD $0x4E, X0, X8
+	VADDPS X8, X0, X0
+	VPSHUFD $0xB1, X0, X8
+	VADDPS X8, X0, X0
+	MOVSS X0, 0(DX)
+
+	// Z1 -> out[1]
+	VEXTRACTF32X8 $1, Z1, Y8
+	VADDPS Y8, Y1, Y1
+	VEXTRACTF128 $1, Y1, X8
+	VADDPS X8, X1, X1
+	VPSHUFD $0x4E, X1, X8
+	VADDPS X8, X1, X1
+	VPSHUFD $0xB1, X1, X8
+	VADDPS X8, X1, X1
+	MOVSS X1, 4(DX)
+
+	// Z2 -> out[2]
+	VEXTRACTF32X8 $1, Z2, Y8
+	VADDPS Y8, Y2, Y2
+	VEXTRACTF128 $1, Y2, X8
+	VADDPS X8, X2, X2
+	VPSHUFD $0x4E, X2, X8
+	VADDPS X8, X2, X2
+	VPSHUFD $0xB1, X2, X8
+	VADDPS X8, X2, X2
+	MOVSS X2, 8(DX)
+
+	// Z3 -> out[3]
+	VEXTRACTF32X8 $1, Z3, Y8
+	VADDPS Y8, Y3, Y3
+	VEXTRACTF128 $1, Y3, X8
+	VADDPS X8, X3, X3
+	VPSHUFD $0x4E, X3, X8
+	VADDPS X8, X3, X3
+	VPSHUFD $0xB1, X3, X8
+	VADDPS X8, X3, X3
+	MOVSS X3, 12(DX)
+
+	// Z4 -> out[4]
+	VEXTRACTF32X8 $1, Z4, Y8
+	VADDPS Y8, Y4, Y4
+	VEXTRACTF128 $1, Y4, X8
+	VADDPS X8, X4, X4
+	VPSHUFD $0x4E, X4, X8
+	VADDPS X8, X4, X4
+	VPSHUFD $0xB1, X4, X8
+	VADDPS X8, X4, X4
+	MOVSS X4, 16(DX)
+
+	// Z5 -> out[5]
+	VEXTRACTF32X8 $1, Z5, Y8
+	VADDPS Y8, Y5, Y5
+	VEXTRACTF128 $1, Y5, X8
+	VADDPS X8, X5, X5
+	VPSHUFD $0x4E, X5, X8
+	VADDPS X8, X5, X5
+	VPSHUFD $0xB1, X5, X8
+	VADDPS X8, X5, X5
+	MOVSS X5, 20(DX)
+
+	// Z6 -> out[6]
+	VEXTRACTF32X8 $1, Z6, Y8
+	VADDPS Y8, Y6, Y6
+	VEXTRACTF128 $1, Y6, X8
+	VADDPS X8, X6, X6
+	VPSHUFD $0x4E, X6, X8
+	VADDPS X8, X6, X6
+	VPSHUFD $0xB1, X6, X8
+	VADDPS X8, X6, X6
+	MOVSS X6, 24(DX)
+
+	// Z7 -> out[7]
+	VEXTRACTF32X8 $1, Z7, Y8
+	VADDPS Y8, Y7, Y7
+	VEXTRACTF128 $1, Y7, X8
+	VADDPS X8, X7, X7
+	VPSHUFD $0x4E, X7, X8
+	VADDPS X8, X7, X7
+	VPSHUFD $0xB1, X7, X8
+	VADDPS X8, X7, X7
+	MOVSS X7, 28(DX)
+
+	VZEROUPPER
+	JMP epilogue
+
+zero:
+	VXORPS X0, X0, X0
+	MOVSS X0, 0(DX)
+	MOVSS X0, 4(DX)
+	MOVSS X0, 8(DX)
+	MOVSS X0, 12(DX)
+	MOVSS X0, 16(DX)
+	MOVSS X0, 20(DX)
+	MOVSS X0, 24(DX)
+	MOVSS X0, 28(DX)
+	VZEROUPPER
+	JMP epilogue
+
+epilogue:
+	MOVQ 0(SP), AX
+	MOVQ 8(SP), BX
+	MOVQ 16(SP), CX
+	MOVQ 24(SP), DX
+	MOVQ 32(SP), DI
+	MOVQ 40(SP), SI
+	MOVQ 48(SP), R8
+	MOVQ 56(SP), R9
+	MOVQ 64(SP), R10
+	MOVQ 72(SP), R11
+	MOVQ 80(SP), R12
+	MOVQ 88(SP), R13
+	RET

+ 109 - 0
pkg/backend/cpu/matmul/gemv_f32_tiled_amd64.go

@@ -0,0 +1,109 @@
+//go:build amd64
+
+package matmul
+
+import "makarna/pkg/backend/cpu"
+
+const f32NR = 8
+
+// gemvFloat32Range computes out[n] = dot(aRow, B[n, :]) for n in [startN, endN), where:
+// - aRow is length K
+// - B is NxK row-major (weights)
+// - out is length at least endN
+//
+// It prefers a register-blocked 1x8 micro-kernel on AVX2/AVX-512.
+func gemvFloat32Range(out, aRow, b []float32, K, startN, endN int) {
+	if startN >= endN {
+		return
+	}
+
+	if cpu.SupportsAVX512() && K >= 16 {
+		gemvFloat32RangeAVX512(out, aRow, b, K, startN, endN)
+		return
+	}
+	if cpu.SupportsAVX2() && K >= 8 {
+		gemvFloat32RangeAVX2(out, aRow, b, K, startN, endN)
+		return
+	}
+
+	bOff := startN * K
+	for n := startN; n < endN; n++ {
+		out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K])
+		bOff += K
+	}
+}
+
+func gemvFloat32RangeAVX2(out, aRow, b []float32, K, startN, endN int) {
+	kMain := K &^ 7
+	if kMain <= 0 {
+		bOff := startN * K
+		for n := startN; n < endN; n++ {
+			out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K])
+			bOff += K
+		}
+		return
+	}
+
+	aTail := aRow[kMain:K]
+	kTail := K - kMain
+
+	bOff := startN * K
+	n := startN
+	for ; n+f32NR <= endN; n += f32NR {
+		gemvF32Tile8AVX2(&aRow[0], &b[bOff], &out[n], K)
+		if kTail != 0 {
+			for t := 0; t < f32NR; t++ {
+				rowOff := bOff + t*K + kMain
+				var tail float32
+				for i := 0; i < kTail; i++ {
+					tail += aTail[i] * b[rowOff+i]
+				}
+				out[n+t] += tail
+			}
+		}
+		bOff += f32NR * K
+	}
+	for ; n < endN; n++ {
+		out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K])
+		bOff += K
+	}
+}
+
+func gemvFloat32RangeAVX512(out, aRow, b []float32, K, startN, endN int) {
+	kMain := K &^ 15
+	if kMain <= 0 {
+		gemvFloat32RangeAVX2(out, aRow, b, K, startN, endN)
+		return
+	}
+
+	aTail := aRow[kMain:K]
+	kTail := K - kMain
+
+	bOff := startN * K
+	n := startN
+	for ; n+f32NR <= endN; n += f32NR {
+		gemvF32Tile8AVX512(&aRow[0], &b[bOff], &out[n], K)
+		if kTail != 0 {
+			for t := 0; t < f32NR; t++ {
+				rowOff := bOff + t*K + kMain
+				var tail float32
+				for i := 0; i < kTail; i++ {
+					tail += aTail[i] * b[rowOff+i]
+				}
+				out[n+t] += tail
+			}
+		}
+		bOff += f32NR * K
+	}
+	for ; n < endN; n++ {
+		out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K])
+		bOff += K
+	}
+}
+
+//go:noescape
+func gemvF32Tile8AVX2(a *float32, b *float32, out *float32, K int)
+
+//go:noescape
+func gemvF32Tile8AVX512(a *float32, b *float32, out *float32, K int)
+

+ 17 - 0
pkg/backend/cpu/matmul/gemv_f32_tiled_generic.go

@@ -0,0 +1,17 @@
+//go:build !amd64
+
+package matmul
+
+import "makarna/pkg/backend/cpu"
+
+func gemvFloat32Range(out, aRow, b []float32, K, startN, endN int) {
+	if startN >= endN {
+		return
+	}
+	bOff := startN * K
+	for n := startN; n < endN; n++ {
+		out[n] = cpu.DotFloat32(aRow, b[bOff:bOff+K])
+		bOff += K
+	}
+}
+

+ 68 - 0
pkg/backend/cpu/matmul/gemv_f32_tiled_test.go

@@ -0,0 +1,68 @@
+package matmul
+
+import (
+	"math"
+	"math/rand"
+	"testing"
+
+	"makarna/pkg/backend/cpu"
+)
+
+func TestGemvFloat32RangeMatchesScalar(t *testing.T) {
+	rng := rand.New(rand.NewSource(1))
+
+	cases := []struct {
+		K      int
+		N      int
+		startN int
+		endN   int
+	}{
+		{K: 3, N: 13, startN: 0, endN: 13},
+		{K: 33, N: 17, startN: 0, endN: 17},
+		{K: 33, N: 17, startN: 5, endN: 13},
+		{K: 32, N: 16, startN: 0, endN: 16},
+	}
+
+	for _, tc := range cases {
+		a := make([]float32, tc.K)
+		for i := range a {
+			a[i] = rng.Float32()*2 - 1
+		}
+		b := make([]float32, tc.N*tc.K)
+		for i := range b {
+			b[i] = rng.Float32()*2 - 1
+		}
+
+		want := make([]float32, tc.N)
+		for n := tc.startN; n < tc.endN; n++ {
+			var sum float32
+			rowOff := n * tc.K
+			for k := 0; k < tc.K; k++ {
+				sum += a[k] * b[rowOff+k]
+			}
+			want[n] = sum
+		}
+
+		// Sanity check: cpu.DotFloat32 should closely match the scalar reference sum.
+		for n := tc.startN; n < tc.endN; n++ {
+			rowOff := n * tc.K
+			gotCPU := cpu.DotFloat32(a, b[rowOff:rowOff+tc.K])
+			diff := math.Abs(float64(gotCPU - want[n]))
+			if diff > 1e-4 {
+				t.Fatalf("DotFloat32 mismatch K=%d n=%d: got=%v want=%v", tc.K, n, gotCPU, want[n])
+			}
+		}
+
+		got := make([]float32, tc.N)
+		gemvFloat32Range(got, a, b, tc.K, tc.startN, tc.endN)
+
+		const tol = 1e-4
+		for n := tc.startN; n < tc.endN; n++ {
+			diff := math.Abs(float64(got[n] - want[n]))
+			if diff > tol {
+				t.Fatalf("K=%d N=%d range=[%d,%d) n=%d: got=%v want=%v diff=%g",
+					tc.K, tc.N, tc.startN, tc.endN, n, got[n], want[n], diff)
+			}
+		}
+	}
+}

+ 13 - 0
pkg/backend/cpu/matmul/linear.go

@@ -0,0 +1,13 @@
+//go:build !cuda
+
+// Package matmul provides matrix multiplication operations
+package matmul
+
+import (
+	"makarna/pkg/backend/cpu"
+)
+
+// Linear dispatches to the CPU implementation when CUDA is disabled.
+func Linear(input, weight, output *cpu.Tensor) error {
+	return linearCPU(input, weight, output)
+}
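A minimal usage sketch (hypothetical main, package paths taken from go.mod); the weight is stored row-major as [N, K], so the result is the input times the transposed weight:

package main

import (
	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/matmul"
	"makarna/pkg/tensor"
)

func main() {
	in := cpu.NewTensor(tensor.Shape{1, 8}, nil)  // M=1, K=8
	w := cpu.NewTensor(tensor.Shape{4, 8}, nil)   // N=4, K=8, row-major weights
	out := cpu.NewTensor(tensor.Shape{1, 4}, nil) // M=1, N=4
	_ = matmul.Linear(in, w, out)                 // out = in * w^T
}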

+ 137 - 0
pkg/backend/cpu/matmul/linear_bench_test.go

@@ -0,0 +1,137 @@
+package matmul
+
+import (
+	"math/rand"
+	"testing"
+	"unsafe"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/tensor"
+)
+
+func BenchmarkLinearF32Decode(b *testing.B) {
+	// Simulate single-token decode: M=1, moderate N
+	M, K, N := 1, 512, 1024
+	in := make([]float32, M*K)
+	for i := range in {
+		in[i] = rand.Float32()
+	}
+	w := make([]float32, N*K)
+	for i := range w {
+		w[i] = rand.Float32()
+	}
+	out := make([]float32, M*N)
+
+	inT := cpu.NewTensor(tensor.Shape{M, K}, in)
+	wT := cpu.NewTensor(tensor.Shape{N, K}, w)
+	outT := cpu.NewTensor(tensor.Shape{M, N}, out)
+
+	b.ReportAllocs()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		Linear(inT, wT, outT)
+	}
+}
+
+func BenchmarkLinearF32Prefill(b *testing.B) {
+	// Prefill-style: larger M
+	M, K, N := 32, 512, 1024
+	in := make([]float32, M*K)
+	for i := range in {
+		in[i] = rand.Float32()
+	}
+	w := make([]float32, N*K)
+	for i := range w {
+		w[i] = rand.Float32()
+	}
+	out := make([]float32, M*N)
+
+	inT := cpu.NewTensor(tensor.Shape{M, K}, in)
+	wT := cpu.NewTensor(tensor.Shape{N, K}, w)
+	outT := cpu.NewTensor(tensor.Shape{M, N}, out)
+
+	b.ReportAllocs()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		Linear(inT, wT, outT)
+	}
+}
+
+func BenchmarkLinearQ4KDecode(b *testing.B) {
+	M, K, N := 1, 512, 1024 // K multiple of 256
+	in := make([]float32, M*K)
+	for i := range in {
+		in[i] = rand.Float32()
+	}
+	weightBlocks := makeQ4Blocks(N, K)
+	out := make([]float32, M*N)
+
+	inT := cpu.NewTensor(tensor.Shape{M, K}, in)
+	wT := makeQ4Tensor(b, tensor.Shape{N, K}, weightBlocks)
+	outT := cpu.NewTensor(tensor.Shape{M, N}, out)
+
+	b.ReportAllocs()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		Linear(inT, wT, outT)
+	}
+}
+
+func BenchmarkLinearQ4KPrefill(b *testing.B) {
+	M, K, N := 16, 512, 512
+	in := make([]float32, M*K)
+	for i := range in {
+		in[i] = rand.Float32()
+	}
+	weightBlocks := makeQ4Blocks(N, K)
+	out := make([]float32, M*N)
+
+	inT := cpu.NewTensor(tensor.Shape{M, K}, in)
+	wT := makeQ4Tensor(b, tensor.Shape{N, K}, weightBlocks)
+	outT := cpu.NewTensor(tensor.Shape{M, N}, out)
+
+	b.ReportAllocs()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		Linear(inT, wT, outT)
+	}
+}
+
+func fillQ4Blocks(bs []tensor.BlockQ4_K) {
+	const fp16One = uint16(0x3c00)
+	for i := range bs {
+		b := &bs[i]
+		b.D = fp16One
+		b.DMin = 0
+		for j := range b.Scales {
+			b.Scales[j] = 1
+		}
+		for j := range b.QS {
+			b.QS[j] = 0x11 // two nibbles set to 1
+		}
+	}
+}
+
+func makeQ4Blocks(N, K int) []tensor.BlockQ4_K {
+	blocksPerRow := K / tensor.QK_K
+	weightBlocks := make([]tensor.BlockQ4_K, N*blocksPerRow)
+	fillQ4Blocks(weightBlocks)
+	return weightBlocks
+}
+
+func makeQ4Tensor(tb testing.TB, shape tensor.Shape, blocks []tensor.BlockQ4_K) *cpu.Tensor {
+	if len(blocks) == 0 {
+		t, err := cpu.NewTensorFromBytes(shape, tensor.Q4_K, []byte{})
+		if err != nil {
+			tb.Fatalf("makeQ4Tensor empty: %v", err)
+		}
+		return t
+	}
+	blockSize := int(unsafe.Sizeof(tensor.BlockQ4_K{}))
+	buf := unsafe.Slice((*byte)(unsafe.Pointer(&blocks[0])), len(blocks)*blockSize)
+	t, err := cpu.NewTensorFromBytes(shape, tensor.Q4_K, buf)
+	if err != nil {
+		tb.Fatalf("makeQ4Tensor: %v", err)
+	}
+	return t
+}

+ 71 - 0
pkg/backend/cpu/matmul/linear_cuda.go

@@ -0,0 +1,71 @@
+//go:build cuda
+
+package matmul
+
+import (
+	"fmt"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/backend/cuda"
+	"makarna/pkg/tensor"
+)
+
+// Linear offloads float32 matmul to CUDA when built with the cuda tag.
+// For non-float32 weights or activations, it falls back to the CPU path.
+func Linear(input, weight, output *cpu.Tensor) error {
+	// Fall back for non-float32 weights or inputs.
+	if weight.DType() != tensor.Float32 || input.DType() != tensor.Float32 || output.DType() != tensor.Float32 {
+		return linearCPU(input, weight, output)
+	}
+
+	inShape := input.Shape()
+	wShape := weight.Shape()
+	if len(inShape) != 2 || len(wShape) != 2 {
+		return fmt.Errorf("linear: expected 2D inputs, got input %v, weight %v", inShape, wShape)
+	}
+
+	M := inShape[0]
+	K := inShape[1]
+	N := wShape[0]
+	if wShape[1] != K {
+		return fmt.Errorf("linear: shape mismatch: input [*, %d] vs weight [%d, %d]", K, N, wShape[1])
+	}
+
+	// Allocate CUDA buffers
+	a, err := cuda.NewTensor(tensor.Shape{M, K}, tensor.Float32, 0)
+	if err != nil {
+		return err
+	}
+	// Weight stays row-major [N, K]
+	b, err := cuda.NewTensor(tensor.Shape{N, K}, tensor.Float32, 0)
+	if err != nil {
+		return err
+	}
+	c, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, 0)
+	if err != nil {
+		return err
+	}
+
+	// Copy input
+	if err := a.CopyFrom(input.DataFloat32()); err != nil {
+		return fmt.Errorf("linear: copy A failed: %w", err)
+	}
+
+	// Copy weight as-is (row-major [N, K]); CUDA kernel handles NT
+	if err := b.CopyFrom(weight.DataFloat32()); err != nil {
+		return fmt.Errorf("linear: copy B failed: %w", err)
+	}
+
+	// MatMul: c = a @ b
+	if err := a.MatMul(b, c); err != nil {
+		return fmt.Errorf("linear: cuda matmul failed: %w", err)
+	}
+
+	// Copy back to CPU output
+	hostC := make([]float32, M*N)
+	if err := c.CopyToHost(hostC); err != nil {
+		return fmt.Errorf("linear: copy C failed: %w", err)
+	}
+	copy(output.DataFloat32(), hostC)
+	return nil
+}

+ 50 - 0
pkg/backend/cpu/matmul/linear_cuda_test.go

@@ -0,0 +1,50 @@
+//go:build cuda
+
+package matmul
+
+import (
+	"testing"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/tensor"
+)
+
+func TestLinearCudaMatchesCPU(t *testing.T) {
+	// Small matrix to compare CPU vs CUDA paths.
+	M, K, N := 4, 8, 3
+	aCPU := cpu.NewTensor(tensor.Shape{M, K}, nil)
+	wCPU := cpu.NewTensor(tensor.Shape{N, K}, nil)
+	outCPU := cpu.NewTensor(tensor.Shape{M, N}, nil)
+	outCUDA := cpu.NewTensor(tensor.Shape{M, N}, nil)
+
+	fillSeq(aCPU.DataFloat32())
+	fillSeq(wCPU.DataFloat32())
+
+	if err := linearCPU(aCPU, wCPU, outCPU); err != nil {
+		t.Fatalf("cpu linear failed: %v", err)
+	}
+	if err := Linear(aCPU, wCPU, outCUDA); err != nil {
+		t.Fatalf("cuda linear failed: %v", err)
+	}
+
+	for i, v := range outCPU.DataFloat32() {
+		got := outCUDA.DataFloat32()[i]
+		if diff := abs32(v - got); diff > 1e-4 {
+			t.Fatalf("mismatch at %d: cpu=%f cuda=%f", i, v, got)
+		}
+	}
+}
+
+func fillSeq(dst []float32) {
+	for i := range dst {
+		dst[i] = float32(i%7 + 1)
+	}
+}
+
+func abs32(v float32) float32 {
+	if v < 0 {
+		return -v
+	}
+	return v
+}
+

+ 688 - 0
pkg/backend/cpu/matmul/linear_shared.go

@@ -0,0 +1,688 @@
+package matmul
+
+import (
+	"fmt"
+	"sync"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/tensor"
+)
+
+// linearCPU contains the original CPU implementations for all supported
+// weight dtypes. Both CPU-only and CUDA-enabled builds reuse this.
+func linearCPU(input, weight, output *cpu.Tensor) error {
+	inShape := input.Shape()
+	wShape := weight.Shape()
+
+	// Validate dimensions
+	if len(inShape) != 2 || len(wShape) != 2 {
+		return fmt.Errorf("linear: expected 2D inputs, got input %v, weight %v", inShape, wShape)
+	}
+
+	M := inShape[0]
+	K := inShape[1]
+	N := wShape[0]
+
+	if wShape[1] != K {
+		return fmt.Errorf("linear: shape mismatch: input [*, %d] vs weight [%d, %d]", K, N, wShape[1])
+	}
+
+	inData := input.DataFloat32()
+	outData := output.DataFloat32()
+	workers := cpu.MaxThreads()
+
+	switch weight.DType() {
+	case tensor.Float32:
+		wData := weight.DataFloat32()
+		gemmFloat32Blocked(outData, inData, wData, M, K, N, workers)
+
+	case tensor.Q4_K:
+		wData := weight.DataQ4_K()
+		if K%256 != 0 {
+			return fmt.Errorf("linear: Q4_K weight K dimension %d must be multiple of 256", K)
+		}
+		wParams := tensor.GetQ4KDotParams(wData)
+
+		blocksPerRow := K / 256
+		work := M * N * K
+		use := chooseWorkers(work, workers)
+		if use == 1 {
+			if M == 1 {
+				q4kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, 0, N)
+				return nil
+			}
+			for m := 0; m < M; m++ {
+				for n := 0; n < N; n++ {
+					var sum float32
+					for b := 0; b < blocksPerRow; b++ {
+						inOffset := m*K + b*256
+						wBlockIdx := n*blocksPerRow + b
+						block := &wData[wBlockIdx]
+						p := &wParams[wBlockIdx]
+						sum += tensor.DotQ4_K_Params(block, p, inData[inOffset:inOffset+256])
+					}
+					outData[m*N+n] = sum
+				}
+			}
+			return nil
+		}
+		var wg sync.WaitGroup
+		if M == 1 {
+			for _, r := range chunkRanges(N, use) {
+				wg.Add(1)
+				start, end := r[0], r[1]
+				go func(s, e int) {
+					defer wg.Done()
+					q4kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, s, e)
+				}(start, end)
+			}
+			wg.Wait()
+			return nil
+		}
+		if M < use {
+			for _, r := range chunkRanges(N, use) {
+				wg.Add(1)
+				start, end := r[0], r[1]
+				go func(s, e int) {
+					defer wg.Done()
+					for n := s; n < e; n++ {
+						for m := 0; m < M; m++ {
+							var sum float32
+							for b := 0; b < blocksPerRow; b++ {
+								inOffset := m*K + b*256
+								wBlockIdx := n*blocksPerRow + b
+								block := &wData[wBlockIdx]
+								p := &wParams[wBlockIdx]
+								sum += tensor.DotQ4_K_Params(block, p, inData[inOffset:inOffset+256])
+							}
+							outData[m*N+n] = sum
+						}
+					}
+				}(start, end)
+			}
+			wg.Wait()
+			return nil
+		}
+		for _, r := range chunkRanges(M, use) {
+			wg.Add(1)
+			start, end := r[0], r[1]
+			go func(s, e int) {
+				defer wg.Done()
+				for m := s; m < e; m++ {
+					for n := 0; n < N; n++ {
+						var sum float32
+						for b := 0; b < blocksPerRow; b++ {
+							inOffset := m*K + b*256
+							wBlockIdx := n*blocksPerRow + b
+							block := &wData[wBlockIdx]
+							p := &wParams[wBlockIdx]
+							sum += tensor.DotQ4_K_Params(block, p, inData[inOffset:inOffset+256])
+						}
+						outData[m*N+n] = sum
+					}
+				}
+			}(start, end)
+		}
+		wg.Wait()
+
+	case tensor.Q8_K:
+		wData := weight.DataQ8_K()
+		if K%256 != 0 {
+			return fmt.Errorf("linear: Q8_K weight K dimension %d must be multiple of 256", K)
+		}
+
+		blocksPerRow := K / 256
+		work := M * N * K
+		use := chooseWorkers(work, workers)
+		if use == 1 {
+			if M == 1 {
+				q8kGemvDecodeTiled(outData[:N], inData[:K], wData, N, blocksPerRow, 0, N)
+				return nil
+			}
+			for m := 0; m < M; m++ {
+				for n := 0; n < N; n++ {
+					var sum float32
+					for b := 0; b < blocksPerRow; b++ {
+						inOffset := m*K + b*256
+						wBlockIdx := n*blocksPerRow + b
+						block := &wData[wBlockIdx]
+						sum += tensor.DotQ8_K(block, inData[inOffset:inOffset+256])
+					}
+					outData[m*N+n] = sum
+				}
+			}
+			return nil
+		}
+		var wg sync.WaitGroup
+		if M == 1 {
+			for _, r := range chunkRanges(N, use) {
+				wg.Add(1)
+				start, end := r[0], r[1]
+				go func(s, e int) {
+					defer wg.Done()
+					q8kGemvDecodeTiled(outData[:N], inData[:K], wData, N, blocksPerRow, s, e)
+				}(start, end)
+			}
+			wg.Wait()
+			return nil
+		}
+		if M < use {
+			for _, r := range chunkRanges(N, use) {
+				wg.Add(1)
+				start, end := r[0], r[1]
+				go func(s, e int) {
+					defer wg.Done()
+					for n := s; n < e; n++ {
+						for m := 0; m < M; m++ {
+							var sum float32
+							for b := 0; b < blocksPerRow; b++ {
+								inOffset := m*K + b*256
+								wBlockIdx := n*blocksPerRow + b
+								block := &wData[wBlockIdx]
+								sum += tensor.DotQ8_K(block, inData[inOffset:inOffset+256])
+							}
+							outData[m*N+n] = sum
+						}
+					}
+				}(start, end)
+			}
+			wg.Wait()
+			return nil
+		}
+		for _, r := range chunkRanges(M, use) {
+			wg.Add(1)
+			start, end := r[0], r[1]
+			go func(s, e int) {
+				defer wg.Done()
+				for m := s; m < e; m++ {
+					for n := 0; n < N; n++ {
+						var sum float32
+						for b := 0; b < blocksPerRow; b++ {
+							inOffset := m*K + b*256
+							wBlockIdx := n*blocksPerRow + b
+							block := &wData[wBlockIdx]
+							sum += tensor.DotQ8_K(block, inData[inOffset:inOffset+256])
+						}
+						outData[m*N+n] = sum
+					}
+				}
+			}(start, end)
+		}
+		wg.Wait()
+
+	case tensor.Q3_K:
+		wData := weight.DataQ3_K()
+		if K%256 != 0 {
+			return fmt.Errorf("linear: Q3_K weight K dimension %d must be multiple of 256", K)
+		}
+		wParams := tensor.GetQ3KDotParams(wData)
+
+		blocksPerRow := K / 256
+		work := M * N * K
+		use := chooseWorkers(work, workers)
+		if use == 1 {
+			if M == 1 {
+				q3kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, 0, N)
+				return nil
+			}
+			for m := 0; m < M; m++ {
+				for n := 0; n < N; n++ {
+					var sum float32
+					for b := 0; b < blocksPerRow; b++ {
+						inOffset := m*K + b*256
+						wBlockIdx := n*blocksPerRow + b
+						block := &wData[wBlockIdx]
+						p := &wParams[wBlockIdx]
+						sum += tensor.DotQ3_K_Params(block, p, inData[inOffset:inOffset+256])
+					}
+					outData[m*N+n] = sum
+				}
+			}
+			return nil
+		}
+		var wg sync.WaitGroup
+		if M == 1 {
+			for _, r := range chunkRanges(N, use) {
+				wg.Add(1)
+				start, end := r[0], r[1]
+				go func(s, e int) {
+					defer wg.Done()
+					q3kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, s, e)
+				}(start, end)
+			}
+			wg.Wait()
+			return nil
+		}
+		if M < use {
+			for _, r := range chunkRanges(N, use) {
+				wg.Add(1)
+				start, end := r[0], r[1]
+				go func(s, e int) {
+					defer wg.Done()
+					for n := s; n < e; n++ {
+						for m := 0; m < M; m++ {
+							var sum float32
+							for b := 0; b < blocksPerRow; b++ {
+								inOffset := m*K + b*256
+								wBlockIdx := n*blocksPerRow + b
+								block := &wData[wBlockIdx]
+								p := &wParams[wBlockIdx]
+								sum += tensor.DotQ3_K_Params(block, p, inData[inOffset:inOffset+256])
+							}
+							outData[m*N+n] = sum
+						}
+					}
+				}(start, end)
+			}
+			wg.Wait()
+			return nil
+		}
+		for _, r := range chunkRanges(M, use) {
+			wg.Add(1)
+			start, end := r[0], r[1]
+			go func(s, e int) {
+				defer wg.Done()
+				for m := s; m < e; m++ {
+					for n := 0; n < N; n++ {
+						var sum float32
+						for b := 0; b < blocksPerRow; b++ {
+							inOffset := m*K + b*256
+							wBlockIdx := n*blocksPerRow + b
+							block := &wData[wBlockIdx]
+							p := &wParams[wBlockIdx]
+							sum += tensor.DotQ3_K_Params(block, p, inData[inOffset:inOffset+256])
+						}
+						outData[m*N+n] = sum
+					}
+				}
+			}(start, end)
+		}
+		wg.Wait()
+
+	case tensor.Q5_K:
+		wData := weight.DataQ5_K()
+		if K%256 != 0 {
+			return fmt.Errorf("linear: Q5_K weight K dimension %d must be multiple of 256", K)
+		}
+		wParams := tensor.GetQ5KDotParams(wData)
+
+		blocksPerRow := K / 256
+		work := M * N * K
+		use := chooseWorkers(work, workers)
+		if use == 1 {
+			if M == 1 {
+				q5kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, 0, N)
+				return nil
+			}
+			for m := 0; m < M; m++ {
+				for n := 0; n < N; n++ {
+					var sum float32
+					for b := 0; b < blocksPerRow; b++ {
+						inOffset := m*K + b*256
+						wBlockIdx := n*blocksPerRow + b
+						block := &wData[wBlockIdx]
+						p := &wParams[wBlockIdx]
+						sum += tensor.DotQ5_K_Params(block, p, inData[inOffset:inOffset+256])
+					}
+					outData[m*N+n] = sum
+				}
+			}
+			return nil
+		}
+		var wg sync.WaitGroup
+		if M == 1 {
+			for _, r := range chunkRanges(N, use) {
+				wg.Add(1)
+				start, end := r[0], r[1]
+				go func(s, e int) {
+					defer wg.Done()
+					q5kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, s, e)
+				}(start, end)
+			}
+			wg.Wait()
+			return nil
+		}
+		if M < use {
+			for _, r := range chunkRanges(N, use) {
+				wg.Add(1)
+				start, end := r[0], r[1]
+				go func(s, e int) {
+					defer wg.Done()
+					for n := s; n < e; n++ {
+						for m := 0; m < M; m++ {
+							var sum float32
+							for b := 0; b < blocksPerRow; b++ {
+								inOffset := m*K + b*256
+								wBlockIdx := n*blocksPerRow + b
+								block := &wData[wBlockIdx]
+								p := &wParams[wBlockIdx]
+								sum += tensor.DotQ5_K_Params(block, p, inData[inOffset:inOffset+256])
+							}
+							outData[m*N+n] = sum
+						}
+					}
+				}(start, end)
+			}
+			wg.Wait()
+			return nil
+		}
+		for _, r := range chunkRanges(M, use) {
+			wg.Add(1)
+			start, end := r[0], r[1]
+			go func(s, e int) {
+				defer wg.Done()
+				for m := s; m < e; m++ {
+					for n := 0; n < N; n++ {
+						var sum float32
+						for b := 0; b < blocksPerRow; b++ {
+							inOffset := m*K + b*256
+							wBlockIdx := n*blocksPerRow + b
+							block := &wData[wBlockIdx]
+							p := &wParams[wBlockIdx]
+							sum += tensor.DotQ5_K_Params(block, p, inData[inOffset:inOffset+256])
+						}
+						outData[m*N+n] = sum
+					}
+				}
+			}(start, end)
+		}
+		wg.Wait()
+
+	case tensor.Q6_K:
+		wData := weight.DataQ6_K()
+		if K%256 != 0 {
+			return fmt.Errorf("linear: Q6_K weight K dimension %d must be multiple of 256", K)
+		}
+		wParams := tensor.GetQ6KDotParams(wData)
+
+		blocksPerRow := K / 256
+		work := M * N * K
+		use := chooseWorkers(work, workers)
+		if use == 1 {
+			if M == 1 {
+				q6kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, 0, N)
+				return nil
+			}
+			for m := 0; m < M; m++ {
+				for n := 0; n < N; n++ {
+					var sum float32
+					for b := 0; b < blocksPerRow; b++ {
+						inOffset := m*K + b*256
+						wBlockIdx := n*blocksPerRow + b
+						block := &wData[wBlockIdx]
+						p := &wParams[wBlockIdx]
+						sum += tensor.DotQ6_K_Params(block, p, inData[inOffset:inOffset+256])
+					}
+					outData[m*N+n] = sum
+				}
+			}
+			return nil
+		}
+		var wg sync.WaitGroup
+		if M == 1 {
+			for _, r := range chunkRanges(N, use) {
+				wg.Add(1)
+				start, end := r[0], r[1]
+				go func(s, e int) {
+					defer wg.Done()
+					q6kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, s, e)
+				}(start, end)
+			}
+			wg.Wait()
+			return nil
+		}
+		if M < use {
+			for _, r := range chunkRanges(N, use) {
+				wg.Add(1)
+				start, end := r[0], r[1]
+				go func(s, e int) {
+					defer wg.Done()
+					for n := s; n < e; n++ {
+						for m := 0; m < M; m++ {
+							var sum float32
+							for b := 0; b < blocksPerRow; b++ {
+								inOffset := m*K + b*256
+								wBlockIdx := n*blocksPerRow + b
+								block := &wData[wBlockIdx]
+								p := &wParams[wBlockIdx]
+								sum += tensor.DotQ6_K_Params(block, p, inData[inOffset:inOffset+256])
+							}
+							outData[m*N+n] = sum
+						}
+					}
+				}(start, end)
+			}
+			wg.Wait()
+			return nil
+		}
+		for _, r := range chunkRanges(M, use) {
+			wg.Add(1)
+			start, end := r[0], r[1]
+			go func(s, e int) {
+				defer wg.Done()
+				for m := s; m < e; m++ {
+					for n := 0; n < N; n++ {
+						var sum float32
+						for b := 0; b < blocksPerRow; b++ {
+							inOffset := m*K + b*256
+							wBlockIdx := n*blocksPerRow + b
+							block := &wData[wBlockIdx]
+							p := &wParams[wBlockIdx]
+							sum += tensor.DotQ6_K_Params(block, p, inData[inOffset:inOffset+256])
+						}
+						outData[m*N+n] = sum
+					}
+				}
+			}(start, end)
+		}
+		wg.Wait()
+
+	case tensor.Q2_K:
+		wData := weight.DataQ2_K()
+		if K%256 != 0 {
+			return fmt.Errorf("linear: Q2_K weight K dimension %d must be multiple of 256", K)
+		}
+		wParams := tensor.GetQ2KDotParams(wData)
+
+		blocksPerRow := K / 256
+		work := M * N * K
+		use := chooseWorkers(work, workers)
+		if use == 1 {
+			if M == 1 {
+				q2kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, 0, N)
+				return nil
+			}
+			for m := 0; m < M; m++ {
+				for n := 0; n < N; n++ {
+					var sum float32
+					for b := 0; b < blocksPerRow; b++ {
+						inOffset := m*K + b*256
+						wBlockIdx := n*blocksPerRow + b
+						block := &wData[wBlockIdx]
+						p := &wParams[wBlockIdx]
+						sum += tensor.DotQ2_K_Params(block, p, inData[inOffset:inOffset+256])
+					}
+					outData[m*N+n] = sum
+				}
+			}
+			return nil
+		}
+		var wg sync.WaitGroup
+		if M == 1 {
+			for _, r := range chunkRanges(N, use) {
+				wg.Add(1)
+				start, end := r[0], r[1]
+				go func(s, e int) {
+					defer wg.Done()
+					q2kGemvDecodeTiled(outData[:N], inData[:K], wData, wParams, N, blocksPerRow, s, e)
+				}(start, end)
+			}
+			wg.Wait()
+			return nil
+		}
+		if M < use {
+			for _, r := range chunkRanges(N, use) {
+				wg.Add(1)
+				start, end := r[0], r[1]
+				go func(s, e int) {
+					defer wg.Done()
+					for n := s; n < e; n++ {
+						for m := 0; m < M; m++ {
+							var sum float32
+							for b := 0; b < blocksPerRow; b++ {
+								inOffset := m*K + b*256
+								wBlockIdx := n*blocksPerRow + b
+								block := &wData[wBlockIdx]
+								p := &wParams[wBlockIdx]
+								sum += tensor.DotQ2_K_Params(block, p, inData[inOffset:inOffset+256])
+							}
+							outData[m*N+n] = sum
+						}
+					}
+				}(start, end)
+			}
+			wg.Wait()
+			return nil
+		}
+		for _, r := range chunkRanges(M, use) {
+			wg.Add(1)
+			start, end := r[0], r[1]
+			go func(s, e int) {
+				defer wg.Done()
+				for m := s; m < e; m++ {
+					for n := 0; n < N; n++ {
+						var sum float32
+						for b := 0; b < blocksPerRow; b++ {
+							inOffset := m*K + b*256
+							wBlockIdx := n*blocksPerRow + b
+							block := &wData[wBlockIdx]
+							p := &wParams[wBlockIdx]
+							sum += tensor.DotQ2_K_Params(block, p, inData[inOffset:inOffset+256])
+						}
+						outData[m*N+n] = sum
+					}
+				}
+			}(start, end)
+		}
+		wg.Wait()
+
+	default:
+		return fmt.Errorf("linear: unsupported weight dtype %v", weight.DType())
+	}
+
+	return nil
+}
+
+func q4kGemvDecodeTiled(out []float32, x []float32, w []tensor.BlockQ4_K, wp []tensor.Q4KDotParams, N, blocksPerRow, startN, endN int) {
+	const tile = 8
+	for n := startN; n < endN; n += tile {
+		tn := endN - n
+		if tn > tile {
+			tn = tile
+		}
+		var sums [tile]float32
+		for b := 0; b < blocksPerRow; b++ {
+			xBlock := &x[b*256]
+			base := n*blocksPerRow + b
+			tensor.DotQ4KTile8(&sums, w, wp, base, blocksPerRow, xBlock, tn)
+		}
+		for t := 0; t < tn; t++ {
+			out[n+t] = sums[t]
+		}
+	}
+}
+
+func q5kGemvDecodeTiled(out []float32, x []float32, w []tensor.BlockQ5_K, wp []tensor.Q5KDotParams, N, blocksPerRow, startN, endN int) {
+	const tile = 8
+	for n := startN; n < endN; n += tile {
+		tn := endN - n
+		if tn > tile {
+			tn = tile
+		}
+		var sums [tile]float32
+		for b := 0; b < blocksPerRow; b++ {
+			xBlock := &x[b*256]
+			base := n*blocksPerRow + b
+			tensor.DotQ5KTile8(&sums, w, wp, base, blocksPerRow, xBlock, tn)
+		}
+		for t := 0; t < tn; t++ {
+			out[n+t] = sums[t]
+		}
+	}
+}
+
+func q6kGemvDecodeTiled(out []float32, x []float32, w []tensor.BlockQ6_K, wp []tensor.Q6KDotParams, N, blocksPerRow, startN, endN int) {
+	const tile = 8
+	for n := startN; n < endN; n += tile {
+		tn := endN - n
+		if tn > tile {
+			tn = tile
+		}
+		var sums [tile]float32
+		for b := 0; b < blocksPerRow; b++ {
+			xBlock := &x[b*256]
+			base := n*blocksPerRow + b
+			tensor.DotQ6KTile8(&sums, w, wp, base, blocksPerRow, xBlock, tn)
+		}
+		for t := 0; t < tn; t++ {
+			out[n+t] = sums[t]
+		}
+	}
+}
+
+func q3kGemvDecodeTiled(out []float32, x []float32, w []tensor.BlockQ3_K, wp []tensor.Q3KDotParams, N, blocksPerRow, startN, endN int) {
+	const tile = 8
+	for n := startN; n < endN; n += tile {
+		tn := endN - n
+		if tn > tile {
+			tn = tile
+		}
+		var sums [tile]float32
+		for b := 0; b < blocksPerRow; b++ {
+			xBlock := &x[b*256]
+			base := n*blocksPerRow + b
+			tensor.DotQ3KTile8(&sums, w, wp, base, blocksPerRow, xBlock, tn)
+		}
+		for t := 0; t < tn; t++ {
+			out[n+t] = sums[t]
+		}
+	}
+}
+
+func q2kGemvDecodeTiled(out []float32, x []float32, w []tensor.BlockQ2_K, wp []tensor.Q2KDotParams, N, blocksPerRow, startN, endN int) {
+	const tile = 8
+	for n := startN; n < endN; n += tile {
+		tn := endN - n
+		if tn > tile {
+			tn = tile
+		}
+		var sums [tile]float32
+		for b := 0; b < blocksPerRow; b++ {
+			xBlock := &x[b*256]
+			base := n*blocksPerRow + b
+			tensor.DotQ2KTile8(&sums, w, wp, base, blocksPerRow, xBlock, tn)
+		}
+		for t := 0; t < tn; t++ {
+			out[n+t] = sums[t]
+		}
+	}
+}
+
+func q8kGemvDecodeTiled(out []float32, x []float32, w []tensor.BlockQ8_K, N, blocksPerRow, startN, endN int) {
+	const tile = 8
+	for n := startN; n < endN; n += tile {
+		tn := endN - n
+		if tn > tile {
+			tn = tile
+		}
+		var sums [tile]float32
+		for b := 0; b < blocksPerRow; b++ {
+			xBlock := &x[b*256]
+			base := n*blocksPerRow + b
+			tensor.DotQ8KTile8(&sums, w, base, blocksPerRow, xBlock, tn)
+		}
+		for t := 0; t < tn; t++ {
+			out[n+t] = sums[t]
+		}
+	}
+}
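For scale (illustrative numbers): with K = 4096 the quantized paths above have blocksPerRow = 16, so a decode call walks the output columns in tiles of 8 and, inside each tile, streams the sixteen 256-float slices of the activation once, reusing each slice across all 8 rows via the DotQ*KTile8 kernels; the activation stays cache-hot while the weight blocks stream through.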

+ 94 - 0
pkg/backend/cpu/nn/activations.go

@@ -0,0 +1,94 @@
+package nn
+
+import (
+	"math"
+
+	"makarna/pkg/backend/cpu"
+)
+
+// Sigmoid applies the sigmoid function: 1 / (1 + exp(-x))
+// Uses numerically stable computation for both positive and negative inputs.
+func Sigmoid(x float32) float32 {
+	if x >= 0 {
+		z := float32(math.Exp(float64(-x)))
+		return 1 / (1 + z)
+	}
+	z := float32(math.Exp(float64(x)))
+	return z / (1 + z)
+}
+
+// SigmoidInplace applies sigmoid to each element of the slice in-place.
+func SigmoidInplace(data []float32) {
+	for i := range data {
+		data[i] = Sigmoid(data[i])
+	}
+}
+
+// SigmoidTensor applies sigmoid in-place to a tensor.
+func SigmoidTensor(x *cpu.Tensor) error {
+	SigmoidInplace(x.DataFloat32())
+	return nil
+}
+
+// Softplus computes log(1 + exp(x)), with numerical stability for large x.
+func Softplus(x float32) float32 {
+	if x > 20 {
+		return x
+	}
+	return float32(math.Log1p(math.Exp(float64(x))))
+}
+
+// SoftplusInplace applies softplus to each element of the slice in-place.
+func SoftplusInplace(data []float32) {
+	for i := range data {
+		data[i] = Softplus(data[i])
+	}
+}
+
+// SoftplusTensor applies softplus in-place to a tensor.
+func SoftplusTensor(x *cpu.Tensor) error {
+	SoftplusInplace(x.DataFloat32())
+	return nil
+}
+
+// L2NormInplace normalizes a vector to unit length in-place.
+// eps is added to prevent division by zero.
+func L2NormInplace(x []float32, eps float32) {
+	if len(x) == 0 {
+		return
+	}
+	ss := float32(0)
+	for _, v := range x {
+		ss += v * v
+	}
+	inv := float32(1.0 / math.Sqrt(float64(ss+eps)))
+	for i := range x {
+		x[i] *= inv
+	}
+}
+
+// L2NormHeads normalizes Q and K vectors per-head for multi-head attention.
+// Shape: [tokens, numHeads * headDim], normalizes each [headDim] segment.
+func L2NormHeads(q, k []float32, tokens, numHeads, headDim int, eps float32) {
+	stride := numHeads * headDim
+	for t := 0; t < tokens; t++ {
+		base := t * stride
+		for h := 0; h < numHeads; h++ {
+			off := base + h*headDim
+			L2NormInplace(q[off:off+headDim], eps)
+			L2NormInplace(k[off:off+headDim], eps)
+		}
+	}
+}
+
+// Exp computes e^x
+func Exp(x float32) float32 {
+	return float32(math.Exp(float64(x)))
+}
+
+// ExpInplace applies exp to each element of the slice in-place.
+func ExpInplace(data []float32) {
+	for i := range data {
+		data[i] = Exp(data[i])
+	}
+}
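Spot checks for the stability branches above: Sigmoid(0) = 0.5, Softplus(0) = ln 2 ≈ 0.693, and for x > 20 the softplus shortcut returns x, which is exact in float32 since the neglected log1p(exp(-x)) term is below 3e-9 there.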

+ 164 - 0
pkg/backend/cpu/nn/attention.go

@@ -0,0 +1,164 @@
+package nn
+
+import (
+	"math"
+	"sync"
+
+	"makarna/pkg/backend/cpu"
+)
+
+// CausalAttention computes causal (masked) scaled dot-product attention with GQA support
+// Q: [seq_len, num_heads * head_dim]
+// K: [seq_len, num_kv_heads * head_dim]
+// V: [seq_len, num_kv_heads * head_dim]
+// Output: [seq_len, num_heads * head_dim]
+//
+// For Grouped Query Attention (GQA): num_heads is a multiple of num_kv_heads
+// Each KV head is shared across (num_heads / num_kv_heads) query heads
+func CausalAttention(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDim int) error {
+	seqLen := q.Shape()[0]
+
+	qData := q.DataFloat32()
+	kData := k.DataFloat32()
+	vData := v.DataFloat32()
+	outData := output.DataFloat32()
+
+	scale := 1.0 / math.Sqrt(float64(headDim))
+
+	// Number of Q heads per KV head (for GQA)
+	groupSize := numHeads / numKVHeads
+	workers := cpu.MaxThreads()
+	if workers < 2 || numHeads < 2 {
+		runCausalHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, seqLen, scale, 0, numHeads)
+		return nil
+	}
+
+	chunk := (numHeads + workers - 1) / workers
+	var wg sync.WaitGroup
+	for start := 0; start < numHeads; start += chunk {
+		end := start + chunk
+		if end > numHeads {
+			end = numHeads
+		}
+		wg.Add(1)
+		go func(s, e int) {
+			defer wg.Done()
+			runCausalHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, seqLen, scale, s, e)
+		}(start, end)
+	}
+	wg.Wait()
+
+	return nil
+}
+
+func runCausalHeads(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDim, groupSize, seqLen int, scale float64, hStart, hEnd int) {
+	// Reuse a per-worker scores buffer to avoid per-token allocations;
+	// the maximum numKeys is seqLen.
+	scoresBuf := make([]float32, seqLen)
+	strideQ := numHeads * headDim
+	strideKV := numKVHeads * headDim
+
+	for h := hStart; h < hEnd; h++ {
+		qHeadOffset := h * headDim
+		kvHead := h / groupSize
+		kvHeadOffset := kvHead * headDim
+
+		for qi := 0; qi < seqLen; qi++ {
+			numKeys := qi + 1
+
+			scores := scoresBuf[:numKeys]
+			qBase := qi*strideQ + qHeadOffset
+			qVec := qData[qBase : qBase+headDim]
+			for ki := 0; ki < numKeys; ki++ {
+				kBase := ki*strideKV + kvHeadOffset
+				kVec := kData[kBase : kBase+headDim]
+				dot := cpu.DotFloat32(qVec, kVec)
+				scores[ki] = dot * float32(scale)
+			}
+
+			softmaxInplace(scores)
+
+			outBase := qi*strideQ + qHeadOffset
+			outVec := outData[outBase : outBase+headDim]
+			clear(outVec)
+			for vi := 0; vi < numKeys; vi++ {
+				alpha := scores[vi]
+				vBase := vi*strideKV + kvHeadOffset
+				vVec := vData[vBase : vBase+headDim]
+				cpu.Axpy(alpha, vVec, outVec)
+			}
+		}
+	}
+}
+
+// Attention computes full (non-causal) attention, e.g. for encoder models.
+func Attention(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDim int) error {
+	seqLen := q.Shape()[0]
+
+	qData := q.DataFloat32()
+	kData := k.DataFloat32()
+	vData := v.DataFloat32()
+	outData := output.DataFloat32()
+
+	scale := 1.0 / math.Sqrt(float64(headDim))
+	groupSize := numHeads / numKVHeads
+	workers := cpu.MaxThreads()
+	if workers < 2 || numHeads < 2 {
+		runFullHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, seqLen, scale, 0, numHeads)
+		return nil
+	}
+
+	chunk := (numHeads + workers - 1) / workers
+	var wg sync.WaitGroup
+	for start := 0; start < numHeads; start += chunk {
+		end := start + chunk
+		if end > numHeads {
+			end = numHeads
+		}
+		wg.Add(1)
+		go func(s, e int) {
+			defer wg.Done()
+			runFullHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, seqLen, scale, s, e)
+		}(start, end)
+	}
+	wg.Wait()
+
+	return nil
+}
+
+func runFullHeads(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDim, groupSize, seqLen int, scale float64, hStart, hEnd int) {
+	// Reuse a per-worker buffer to avoid per-token allocations.
+	scores := make([]float32, seqLen)
+	strideQ := numHeads * headDim
+	strideKV := numKVHeads * headDim
+
+	for h := hStart; h < hEnd; h++ {
+		qHeadOffset := h * headDim
+		kvHead := h / groupSize
+		kvHeadOffset := kvHead * headDim
+
+		for qi := 0; qi < seqLen; qi++ {
+
+			qBase := qi*strideQ + qHeadOffset
+			qVec := qData[qBase : qBase+headDim]
+			for ki := 0; ki < seqLen; ki++ {
+				kBase := ki*strideKV + kvHeadOffset
+				kVec := kData[kBase : kBase+headDim]
+				dot := cpu.DotFloat32(qVec, kVec)
+				scores[ki] = dot * float32(scale)
+			}
+
+			softmaxInplace(scores)
+
+			outBase := qi*strideQ + qHeadOffset
+			outVec := outData[outBase : outBase+headDim]
+			clear(outVec)
+			for vi := 0; vi < seqLen; vi++ {
+				alpha := scores[vi]
+				vBase := vi*strideKV + kvHeadOffset
+				vVec := vData[vBase : vBase+headDim]
+				cpu.Axpy(alpha, vVec, outVec)
+			}
+		}
+	}
+}
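As a concrete GQA example: with numHeads = 32 and numKVHeads = 8, groupSize = 4, so query head h reads KV head h/4; heads 0-3 share KV head 0, heads 4-7 share KV head 1, and so on, which is exactly the kvHead := h / groupSize indexing in the per-head workers above.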

+ 192 - 0
pkg/backend/cpu/nn/attention_batch.go

@@ -0,0 +1,192 @@
+package nn
+
+import (
+	"fmt"
+	"math"
+	"sync"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/kvcache"
+)
+
+// CausalAttentionPackedBlocksBatch computes causal attention for a batch where each
+// row in q/output is an independent sequence (decode-style batch).
+//
+// q:      [numTokens, numHeads*headDim] (float32, CPU)
+// views:  per-token packed KV views (head-major blocks)
+// output: [numTokens, numHeads*headDim] (float32, CPU)
+// queryPos: per-token absolute query position (startPos for that token).
+//
+// This is optimized for CPU KV caches and uses SIMD-backed Dot/Axpy when available.
+func CausalAttentionPackedBlocksBatch(
+	q *cpu.Tensor,
+	viewsByToken [][]kvcache.PackedView,
+	output *cpu.Tensor,
+	numHeads, numKVHeads, headDim int,
+	queryPos []int,
+) error {
+	if q == nil || output == nil {
+		return fmt.Errorf("nil tensor")
+	}
+	if numHeads <= 0 || numKVHeads <= 0 || headDim <= 0 {
+		return fmt.Errorf("invalid heads/dim: numHeads=%d numKVHeads=%d headDim=%d", numHeads, numKVHeads, headDim)
+	}
+	if numHeads%numKVHeads != 0 {
+		return fmt.Errorf("numHeads %d not divisible by numKVHeads %d", numHeads, numKVHeads)
+	}
+
+	qShape := q.Shape()
+	outShape := output.Shape()
+	if len(qShape) != 2 || len(outShape) != 2 {
+		return fmt.Errorf("expected 2D tensors (q=%v out=%v)", qShape, outShape)
+	}
+	numTokens := qShape[0]
+	qStride := qShape[1]
+	if numTokens <= 0 {
+		return nil
+	}
+	if outShape[0] != numTokens || outShape[1] != qStride {
+		return fmt.Errorf("output shape mismatch: q=%v out=%v", qShape, outShape)
+	}
+	if qStride != numHeads*headDim {
+		return fmt.Errorf("q stride mismatch: got %d want %d (numHeads*headDim)", qStride, numHeads*headDim)
+	}
+	if len(viewsByToken) != numTokens {
+		return fmt.Errorf("viewsByToken len %d != numTokens %d", len(viewsByToken), numTokens)
+	}
+	if len(queryPos) != numTokens {
+		return fmt.Errorf("queryPos len %d != numTokens %d", len(queryPos), numTokens)
+	}
+
+	qData := q.DataFloat32()
+	outData := output.DataFloat32()
+
+	scale := float32(1.0 / math.Sqrt(float64(headDim)))
+	groupSize := numHeads / numKVHeads
+
+	workItems := numTokens * numHeads
+	workers := cpu.MaxThreads()
+	if workers < 2 || workItems < 2 {
+		runPackedBatchWork(qData, outData, viewsByToken, queryPos, numTokens, numHeads, numKVHeads, headDim, groupSize, qStride, scale, 0, workItems)
+		return nil
+	}
+
+	if workers > workItems {
+		workers = workItems
+	}
+
+	chunk := (workItems + workers - 1) / workers
+	var wg sync.WaitGroup
+	wg.Add(workers)
+	for w := 0; w < workers; w++ {
+		start := w * chunk
+		end := start + chunk
+		if end > workItems {
+			end = workItems
+		}
+		go func(s, e int) {
+			defer wg.Done()
+			runPackedBatchWork(qData, outData, viewsByToken, queryPos, numTokens, numHeads, numKVHeads, headDim, groupSize, qStride, scale, s, e)
+		}(start, end)
+	}
+	wg.Wait()
+	return nil
+}
+
+func runPackedBatchWork(
+	qData, outData []float32,
+	viewsByToken [][]kvcache.PackedView,
+	queryPos []int,
+	numTokens, numHeads, numKVHeads, headDim, groupSize, qStride int,
+	scale float32,
+	workStart, workEnd int,
+) {
+	for idx := workStart; idx < workEnd; idx++ {
+		tok := idx / numHeads
+		head := idx - tok*numHeads
+
+		if tok < 0 || tok >= numTokens {
+			continue
+		}
+
+		qHeadOffset := head * headDim
+		qBase := tok*qStride + qHeadOffset
+		if qBase < 0 || qBase+headDim > len(qData) {
+			continue
+		}
+
+		outBase := tok*qStride + qHeadOffset
+		if outBase < 0 || outBase+headDim > len(outData) {
+			continue
+		}
+
+		qPtr := &qData[qBase]
+		outVec := outData[outBase : outBase+headDim]
+		outPtr := &outData[outBase]
+		clear(outVec)
+
+		kvHead := head / groupSize
+		maxKeyPos := queryPos[tok] + 1
+		if maxKeyPos <= 0 {
+			continue
+		}
+
+		m := float32(-math.MaxFloat32)
+		l := float32(0)
+
+		for _, pv := range viewsByToken[tok] {
+			if pv.Length == 0 || pv.Start >= maxKeyPos {
+				continue
+			}
+			if pv.HeadDim != headDim || pv.NumKVHeads != numKVHeads {
+				continue
+			}
+			blkStride := pv.BlockSize * headDim
+			headBase := kvHead * blkStride
+			if blkStride <= 0 || headBase < 0 || headBase+blkStride > len(pv.K) || headBase+blkStride > len(pv.V) {
+				continue
+			}
+			viewLimit := pv.Length
+			if pv.Start+viewLimit > maxKeyPos {
+				viewLimit = maxKeyPos - pv.Start
+			}
+			if viewLimit <= 0 {
+				continue
+			}
+
+			kHead := pv.K[headBase : headBase+blkStride]
+			vHead := pv.V[headBase : headBase+blkStride]
+			for t := 0; t < viewLimit; t++ {
+				kPtr := &kHead[t*headDim]
+				s := cpu.DotFloat32Ptr(qPtr, kPtr, headDim) * scale
+				vPtr := &vHead[t*headDim]
+
+				if s > m {
+					alpha := expf(m - s)
+					if l != 0 {
+						for i := 0; i < headDim; i++ {
+							outVec[i] *= alpha
+						}
+						l *= alpha
+					}
+					m = s
+					l += 1
+					cpu.AxpyPtr(1, vPtr, outPtr, headDim)
+					continue
+				}
+
+				w := expf(s - m)
+				l += w
+				cpu.AxpyPtr(w, vPtr, outPtr, headDim)
+			}
+		}
+
+		if l != 0 {
+			inv := 1 / l
+			for i := 0; i < headDim; i++ {
+				outVec[i] *= inv
+			}
+		}
+	}
+}
+
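
A short sketch of the flat token-by-head work split used above, so each goroutine gets a contiguous range of (token, head) pairs; the counts are illustrative.

package main

import "fmt"

func main() {
	// Mirrors runPackedBatchWork: a single index enumerates (token, head)
	// pairs so the batch can be chunked evenly across workers.
	numTokens, numHeads := 3, 4
	workItems := numTokens * numHeads
	for idx := 0; idx < workItems; idx++ {
		tok := idx / numHeads
		head := idx - tok*numHeads // equivalent to idx % numHeads
		fmt.Printf("idx=%2d -> tok=%d head=%d\n", idx, tok, head)
	}
}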

+ 102 - 0
pkg/backend/cpu/nn/attention_batch_test.go

@@ -0,0 +1,102 @@
+package nn
+
+import (
+	"math"
+	"math/rand"
+	"testing"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/kvcache"
+	"makarna/pkg/tensor"
+)
+
+func TestCausalAttentionPackedBlocksBatch_MatchesPerToken(t *testing.T) {
+	numTokens := 7
+	numHeads := 8
+	numKVHeads := 2
+	headDim := 16
+	blockSize := 4
+
+	qStride := numHeads * headDim
+
+	queryPos := []int{0, 3, 5, 1, 7, 2, 9}
+	if len(queryPos) != numTokens {
+		t.Fatalf("test bug: queryPos len %d != numTokens %d", len(queryPos), numTokens)
+	}
+
+	rng := rand.New(rand.NewSource(1234))
+
+	viewsByToken := make([][]kvcache.PackedView, numTokens)
+	for tok := 0; tok < numTokens; tok++ {
+		kvLen := queryPos[tok] + 1
+		if kvLen <= 0 {
+			t.Fatalf("kvLen must be > 0, got %d", kvLen)
+		}
+		numBlocks := (kvLen + blockSize - 1) / blockSize
+		views := make([]kvcache.PackedView, 0, numBlocks)
+		for b := 0; b < numBlocks; b++ {
+			start := b * blockSize
+			length := blockSize
+			if start+length > kvLen {
+				length = kvLen - start
+			}
+			if length <= 0 {
+				break
+			}
+
+			blkStride := blockSize * headDim
+			k := make([]float32, numKVHeads*blkStride)
+			v := make([]float32, numKVHeads*blkStride)
+			for i := range k {
+				k[i] = rng.Float32()*2 - 1
+				v[i] = rng.Float32()*2 - 1
+			}
+
+			views = append(views, kvcache.PackedView{
+				K:          k,
+				V:          v,
+				Start:      start,
+				Length:     length,
+				BlockSize:  blockSize,
+				HeadDim:    headDim,
+				NumKVHeads: numKVHeads,
+			})
+		}
+		viewsByToken[tok] = views
+	}
+
+	q := cpu.NewTensor(tensor.Shape{numTokens, qStride}, nil)
+	qData := q.DataFloat32()
+	for i := range qData {
+		qData[i] = rng.Float32()*2 - 1
+	}
+
+	outBatch := cpu.NewTensor(tensor.Shape{numTokens, qStride}, nil)
+	if err := CausalAttentionPackedBlocksBatch(q, viewsByToken, outBatch, numHeads, numKVHeads, headDim, queryPos); err != nil {
+		t.Fatalf("CausalAttentionPackedBlocksBatch: %v", err)
+	}
+
+	outRef := cpu.NewTensor(tensor.Shape{numTokens, qStride}, nil)
+	outRefData := outRef.DataFloat32()
+	for tok := 0; tok < numTokens; tok++ {
+		qRow := cpu.NewTensor(tensor.Shape{1, qStride}, qData[tok*qStride:(tok+1)*qStride])
+		outRow := cpu.NewTensor(tensor.Shape{1, qStride}, outRefData[tok*qStride:(tok+1)*qStride])
+		if err := CausalAttentionPackedBlocks(qRow, viewsByToken[tok], outRow, numHeads, numKVHeads, headDim, queryPos[tok]); err != nil {
+			t.Fatalf("CausalAttentionPackedBlocks tok=%d: %v", tok, err)
+		}
+	}
+
+	got := outBatch.DataFloat32()
+	want := outRef.DataFloat32()
+	if len(got) != len(want) {
+		t.Fatalf("length mismatch: got %d want %d", len(got), len(want))
+	}
+
+	const tol = 1e-4
+	for i := range got {
+		if diff := math.Abs(float64(got[i] - want[i])); diff > tol {
+			t.Fatalf("mismatch at %d: got=%g want=%g diff=%g", i, got[i], want[i], diff)
+		}
+	}
+}
+

+ 454 - 0
pkg/backend/cpu/nn/attention_cached.go

@@ -0,0 +1,454 @@
+package nn
+
+import (
+	"fmt"
+	"math"
+	"sort"
+	"sync"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/backend/cuda"
+	"makarna/pkg/kvcache"
+	"makarna/pkg/tensor"
+)
+
+var useFastExp = true
+
+func float16BitsToFloat32(bits uint16) float32 {
+	sign := uint32(bits&0x8000) << 16
+	exp := int32((bits & 0x7C00) >> 10)
+	mant := uint32(bits & 0x03FF)
+
+	if exp == 0 {
+		if mant == 0 {
+			return math.Float32frombits(sign)
+		}
+		for mant&0x0400 == 0 {
+			mant <<= 1
+			exp--
+		}
+		exp++
+		mant &= 0x03FF
+	} else if exp == 0x1F {
+		if mant == 0 {
+			return math.Float32frombits(sign | 0x7F800000)
+		}
+		return math.Float32frombits(sign | 0x7FC00000)
+	}
+
+	exp = exp + (127 - 15)
+	return math.Float32frombits(sign | (uint32(exp) << 23) | (mant << 13))
+}
+
+func bfloat16BitsToFloat32(bits uint16) float32 {
+	return math.Float32frombits(uint32(bits) << 16)
+}
+
+func expf(x float32) float32 {
+	if useFastExp {
+		// Clamp to a reasonable range for stability.
+		// For softmax weights, very negative values underflow to ~0 anyway.
+		if x < -20 {
+			x = -20
+		} else if x > 10 {
+			x = 10
+		}
+		// Schraudolph-style fast exp approximation.
+		// Good tradeoff for softmax weights; much faster than math.Exp.
+		const a = 12102203.0 // (1<<23)/ln(2)
+		const b = 1065353216.0 // 127<<23: the float32 exponent bias in integer bits
+		return math.Float32frombits(uint32(float32(a)*x + float32(b)))
+	}
+	return float32(math.Exp(float64(x)))
+}
+
+type viewData struct {
+	kData  []float32
+	vData  []float32
+	start  int
+	length int
+}
+
+// CausalAttentionCached computes causal attention using cached K/V
+// Q: [newTokens, numHeads * headDim] - query for new tokens only
+// K: [totalSeqLen, numKVHeads * headDim] - full K history including current
+// V: [totalSeqLen, numKVHeads * headDim] - full V history including current
+// Output: [newTokens, numHeads * headDim]
+// startPos: position of first new token in sequence
+func CausalAttentionCached(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDim, startPos int) error {
+	newTokens := q.Shape()[0]
+	totalSeqLen := k.Shape()[0]
+
+	qData := q.DataFloat32()
+	kData := k.DataFloat32()
+	vData := v.DataFloat32()
+	outData := output.DataFloat32()
+
+	scale := 1.0 / math.Sqrt(float64(headDim))
+	groupSize := numHeads / numKVHeads
+
+	workers := cpu.MaxThreads()
+	if workers < 2 || numHeads < 2 {
+		runCausalCachedHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, newTokens, totalSeqLen, startPos, scale, 0, numHeads)
+		return nil
+	}
+
+	chunk := (numHeads + workers - 1) / workers
+	var wg sync.WaitGroup
+	for start := 0; start < numHeads; start += chunk {
+		end := start + chunk
+		if end > numHeads {
+			end = numHeads
+		}
+		wg.Add(1)
+		go func(s, e int) {
+			defer wg.Done()
+			runCausalCachedHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, newTokens, totalSeqLen, startPos, scale, s, e)
+		}(start, end)
+	}
+	wg.Wait()
+
+	return nil
+}
+
+// CausalAttentionPackedBlocks computes causal attention over packed KV views.
+// Packed layout is head-major: [kvHead][tokenWithinBlock][headDim] as a flat slice.
+// This avoids kvDim-stride traversal and is a fast CPU path.
+func CausalAttentionPackedBlocks(
+	q *cpu.Tensor,
+	views []kvcache.PackedView,
+	output *cpu.Tensor,
+	numHeads, numKVHeads, headDim, startPos int,
+) error {
+	newTokens := q.Shape()[0]
+	qData := q.DataFloat32()
+	outData := output.DataFloat32()
+
+	scale := 1.0 / math.Sqrt(float64(headDim))
+	groupSize := numHeads / numKVHeads
+
+	// Sort in place so views are visited in increasing start order (this reorders the caller's slice).
+	sort.Slice(views, func(i, j int) bool {
+		return views[i].Start < views[j].Start
+	})
+
+	workers := cpu.MaxThreads()
+	if workers < 2 || numHeads < 2 {
+		runCausalPackedHeads(qData, outData, views, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, 0, numHeads)
+		return nil
+	}
+
+	chunk := (numHeads + workers - 1) / workers
+	var wg sync.WaitGroup
+	for start := 0; start < numHeads; start += chunk {
+		end := start + chunk
+		if end > numHeads {
+			end = numHeads
+		}
+		wg.Add(1)
+		go func(s, e int) {
+			defer wg.Done()
+			runCausalPackedHeads(qData, outData, views, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, s, e)
+		}(start, end)
+	}
+	wg.Wait()
+
+	return nil
+}
+
+func runCausalCachedHeads(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDim, groupSize, newTokens, totalSeqLen, startPos int, scale float64, hStart, hEnd int) {
+	strideQ := numHeads * headDim
+	strideKV := numKVHeads * headDim
+
+	for h := hStart; h < hEnd; h++ {
+		qHeadOffset := h * headDim
+		kvHead := h / groupSize
+		kvHeadOffset := kvHead * headDim
+
+		for qi := 0; qi < newTokens; qi++ {
+			maxKeyPos := startPos + qi + 1
+			if maxKeyPos > totalSeqLen {
+				maxKeyPos = totalSeqLen
+			}
+			qBase := qi*strideQ + qHeadOffset
+			qPtr := &qData[qBase]
+
+			outBase := qi*strideQ + qHeadOffset
+			outVec := outData[outBase : outBase+headDim]
+			outPtr := &outData[outBase]
+			clear(outVec)
+
+			m := float32(-math.MaxFloat32)
+			l := float32(0)
+			for ti := 0; ti < maxKeyPos; ti++ {
+				kBase := ti*strideKV + kvHeadOffset
+				kPtr := &kData[kBase]
+				s := cpu.DotFloat32Ptr(qPtr, kPtr, headDim) * float32(scale)
+				vBase := ti*strideKV + kvHeadOffset
+				vPtr := &vData[vBase]
+
+				if s > m {
+					alpha := expf(m - s)
+					if l != 0 {
+						for i := 0; i < headDim; i++ {
+							outVec[i] *= alpha
+						}
+						l *= alpha
+					}
+					m = s
+					l += 1
+					cpu.AxpyPtr(1, vPtr, outPtr, headDim)
+					continue
+				}
+
+				w := expf(s - m)
+				l += w
+				cpu.AxpyPtr(w, vPtr, outPtr, headDim)
+			}
+			if l != 0 {
+				inv := 1 / l
+				for i := 0; i < headDim; i++ {
+					outVec[i] *= inv
+				}
+			}
+		}
+	}
+}
+
+func runCausalPackedHeads(qData, outData []float32, views []kvcache.PackedView, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos int, scale float64, hStart, hEnd int) {
+	strideQ := numHeads * headDim
+
+	for h := hStart; h < hEnd; h++ {
+		qHeadOffset := h * headDim
+		kvHead := h / groupSize
+
+		for qi := 0; qi < newTokens; qi++ {
+			maxKeyPos := startPos + qi + 1
+			qBase := qi*strideQ + qHeadOffset
+			qPtr := &qData[qBase]
+
+			outBase := qi*strideQ + qHeadOffset
+			outVec := outData[outBase : outBase+headDim]
+			outPtr := &outData[outBase]
+			clear(outVec)
+
+			m := float32(-math.MaxFloat32)
+			l := float32(0)
+			for _, pv := range views {
+				if pv.Length == 0 || pv.Start >= maxKeyPos {
+					continue
+				}
+				if pv.HeadDim != headDim || pv.NumKVHeads != numKVHeads {
+					continue
+				}
+				blkStride := pv.BlockSize * headDim
+				headBase := kvHead * blkStride
+				if headBase+blkStride > len(pv.K) || headBase+blkStride > len(pv.V) {
+					continue
+				}
+				viewLimit := pv.Length
+				if pv.Start+viewLimit > maxKeyPos {
+					viewLimit = maxKeyPos - pv.Start
+				}
+				kHead := pv.K[headBase : headBase+blkStride]
+				vHead := pv.V[headBase : headBase+blkStride]
+				for t := 0; t < viewLimit; t++ {
+					kPtr := &kHead[t*headDim]
+					s := cpu.DotFloat32Ptr(qPtr, kPtr, headDim) * float32(scale)
+					vPtr := &vHead[t*headDim]
+
+					if s > m {
+						alpha := expf(m - s)
+						if l != 0 {
+							for i := 0; i < headDim; i++ {
+								outVec[i] *= alpha
+							}
+							l *= alpha
+						}
+						m = s
+						l += 1
+						cpu.AxpyPtr(1, vPtr, outPtr, headDim)
+						continue
+					}
+
+					w := expf(s - m)
+					l += w
+					cpu.AxpyPtr(w, vPtr, outPtr, headDim)
+				}
+			}
+			if l != 0 {
+				inv := 1 / l
+				for i := 0; i < headDim; i++ {
+					outVec[i] *= inv
+				}
+			}
+		}
+	}
+}
+
+// CausalAttentionBlocks computes attention directly over KV block views without
+// materializing a contiguous history tensor. startPos is the absolute position
+// of the first new token (current cache length before the append).
+func CausalAttentionBlocks(
+	q *cpu.Tensor,
+	views []kvcache.View,
+	output *cpu.Tensor,
+	numHeads, numKVHeads, headDim, startPos int,
+) error {
+	newTokens := q.Shape()[0]
+	qData := q.DataFloat32()
+	outData := output.DataFloat32()
+
+	scale := 1.0 / math.Sqrt(float64(headDim))
+	groupSize := numHeads / numKVHeads
+
+	// Pre-extract data from all views (handles CPU and GPU tensors)
+	viewsData := make([]viewData, len(views))
+	for i, v := range views {
+		if v.Length == 0 {
+			continue
+		}
+		kData, err := tensorToFloat32(v.K)
+		if err != nil {
+			return fmt.Errorf("failed to get K data from view: %w", err)
+		}
+		vData, err := tensorToFloat32(v.V)
+		if err != nil {
+			return fmt.Errorf("failed to get V data from view: %w", err)
+		}
+		viewsData[i] = viewData{
+			kData:  kData,
+			vData:  vData,
+			start:  v.Start,
+			length: v.Length,
+		}
+	}
+	sort.Slice(viewsData, func(i, j int) bool {
+		return viewsData[i].start < viewsData[j].start
+	})
+
+	workers := cpu.MaxThreads()
+	if workers < 2 || numHeads < 2 {
+		runCausalBlockHeads(qData, outData, viewsData, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, 0, numHeads)
+		return nil
+	}
+
+	chunk := (numHeads + workers - 1) / workers
+	var wg sync.WaitGroup
+	for start := 0; start < numHeads; start += chunk {
+		end := start + chunk
+		if end > numHeads {
+			end = numHeads
+		}
+		wg.Add(1)
+		go func(s, e int) {
+			defer wg.Done()
+			runCausalBlockHeads(qData, outData, viewsData, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, s, e)
+		}(start, end)
+	}
+	wg.Wait()
+
+	return nil
+}
+
+func runCausalBlockHeads(qData, outData []float32, viewsData []viewData, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos int, scale float64, hStart, hEnd int) {
+	strideQ := numHeads * headDim
+	strideKV := numKVHeads * headDim
+
+	for h := hStart; h < hEnd; h++ {
+		qHeadOffset := h * headDim
+		kvHead := h / groupSize
+		kvHeadOffset := kvHead * headDim
+
+		for qi := 0; qi < newTokens; qi++ {
+			maxKeyPos := startPos + qi + 1
+			qBase := qi*strideQ + qHeadOffset
+			qVec := qData[qBase : qBase+headDim]
+
+			outBase := qi*strideQ + qHeadOffset
+			outVec := outData[outBase : outBase+headDim]
+			clear(outVec)
+
+			m := float32(-math.MaxFloat32)
+			l := float32(0)
+			for _, vd := range viewsData {
+				if vd.start >= maxKeyPos || vd.length == 0 {
+					continue
+				}
+				viewLimit := vd.length
+				if vd.start+viewLimit > maxKeyPos {
+					viewLimit = maxKeyPos - vd.start
+				}
+				for local := 0; local < viewLimit; local++ {
+					kvIdx := local*strideKV + kvHeadOffset
+					kVec := vd.kData[kvIdx : kvIdx+headDim]
+					s := cpu.DotFloat32(qVec, kVec) * float32(scale)
+					vVec := vd.vData[kvIdx : kvIdx+headDim]
+
+					if s > m {
+						alpha := expf(m - s)
+						if l != 0 {
+							for i := 0; i < headDim; i++ {
+								outVec[i] *= alpha
+							}
+							l *= alpha
+						}
+						m = s
+						l += 1
+						cpu.Axpy(1, vVec, outVec)
+						continue
+					}
+
+					w := expf(s - m)
+					l += w
+					cpu.Axpy(w, vVec, outVec)
+				}
+			}
+			if l != 0 {
+				inv := 1 / l
+				for i := 0; i < headDim; i++ {
+					outVec[i] *= inv
+				}
+			}
+		}
+	}
+}
+
+// tensorToFloat32 extracts float32 data from a tensor, handling both CPU and CUDA tensors.
+func tensorToFloat32(t tensor.Tensor) ([]float32, error) {
+	switch tt := t.(type) {
+	case *cpu.Tensor:
+		switch tt.DType() {
+		case tensor.Float32:
+			return tt.DataFloat32(), nil
+		case tensor.Float16:
+			in := tt.DataUint16()
+			out := make([]float32, len(in))
+			for i := range in {
+				out[i] = float16BitsToFloat32(in[i])
+			}
+			return out, nil
+		case tensor.BFloat16:
+			in := tt.DataUint16()
+			out := make([]float32, len(in))
+			for i := range in {
+				out[i] = bfloat16BitsToFloat32(in[i])
+			}
+			return out, nil
+		default:
+			return nil, fmt.Errorf("unsupported CPU tensor dtype: %v", tt.DType())
+		}
+	case *cuda.Tensor:
+		data := make([]float32, t.Shape().NumElements())
+		if err := tt.CopyToHost(data); err != nil {
+			return nil, err
+		}
+		return data, nil
+	default:
+		return nil, fmt.Errorf("unsupported tensor type: %T", t)
+	}
+}
+
+func cpuDevice() tensor.DeviceType { return tensor.CPU }
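
The cached and packed kernels above never materialize a full score row: they keep a running maximum m and normalizer l and rescale the accumulator whenever a larger score appears. A scalar reference of that recurrence, which should agree with a conventional two-pass softmax up to rounding; the inputs are illustrative.

package main

import (
	"fmt"
	"math"
)

// onlineWeightedSum returns sum_i softmax(scores)_i * values[i] using the same
// running-max/normalizer update as runCausalCachedHeads, but in scalar form.
func onlineWeightedSum(scores, values []float64) float64 {
	m := math.Inf(-1)
	l, acc := 0.0, 0.0
	for i, s := range scores {
		if s > m {
			alpha := math.Exp(m - s) // rescale everything accumulated so far
			acc *= alpha
			l = l*alpha + 1 // current element contributes exp(s-s) = 1
			m = s
			acc += values[i]
			continue
		}
		w := math.Exp(s - m)
		l += w
		acc += w * values[i]
	}
	if l == 0 {
		return 0
	}
	return acc / l
}

func main() {
	scores := []float64{0.5, 2.0, -1.0}
	values := []float64{10, 20, 30}
	fmt.Println(onlineWeightedSum(scores, values)) // matches a two-pass softmax dot product
}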
+

+ 108 - 0
pkg/backend/cpu/nn/attention_cached_kv.go

@@ -0,0 +1,108 @@
+package nn
+
+import (
+	"math"
+	"sync"
+
+	"makarna/pkg/backend/cpu"
+)
+
+func CausalAttentionCachedKV(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDimK, headDimV, startPos int) error {
+	newTokens := q.Shape()[0]
+	totalSeqLen := k.Shape()[0]
+
+	qData := q.DataFloat32()
+	kData := k.DataFloat32()
+	vData := v.DataFloat32()
+	outData := output.DataFloat32()
+
+	scale := 1.0 / math.Sqrt(float64(headDimK))
+	groupSize := numHeads / numKVHeads
+
+	workers := cpu.MaxThreads()
+	if workers < 2 || numHeads < 2 {
+		runCausalCachedHeadsKV(qData, kData, vData, outData, numHeads, numKVHeads, headDimK, headDimV, groupSize, newTokens, totalSeqLen, startPos, scale, 0, numHeads)
+		return nil
+	}
+
+	chunk := (numHeads + workers - 1) / workers
+	var wg sync.WaitGroup
+	for start := 0; start < numHeads; start += chunk {
+		end := start + chunk
+		if end > numHeads {
+			end = numHeads
+		}
+		wg.Add(1)
+		go func(s, e int) {
+			defer wg.Done()
+			runCausalCachedHeadsKV(qData, kData, vData, outData, numHeads, numKVHeads, headDimK, headDimV, groupSize, newTokens, totalSeqLen, startPos, scale, s, e)
+		}(start, end)
+	}
+	wg.Wait()
+
+	return nil
+}
+
+func runCausalCachedHeadsKV(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDimK, headDimV, groupSize, newTokens, totalSeqLen, startPos int, scale float64, hStart, hEnd int) {
+	strideQ := numHeads * headDimK
+	strideK := numKVHeads * headDimK
+	strideV := numKVHeads * headDimV
+	strideOut := numHeads * headDimV
+
+	for h := hStart; h < hEnd; h++ {
+		qHeadOffset := h * headDimK
+		outHeadOffset := h * headDimV
+		kvHead := h / groupSize
+		kHeadOffset := kvHead * headDimK
+		vHeadOffset := kvHead * headDimV
+
+		for qi := 0; qi < newTokens; qi++ {
+			maxKeyPos := startPos + qi + 1
+			if maxKeyPos > totalSeqLen {
+				maxKeyPos = totalSeqLen
+			}
+
+			qBase := qi*strideQ + qHeadOffset
+			qPtr := &qData[qBase]
+
+			outBase := qi*strideOut + outHeadOffset
+			outVec := outData[outBase : outBase+headDimV]
+			outPtr := &outData[outBase]
+			clear(outVec)
+
+			m := float32(-math.MaxFloat32)
+			l := float32(0)
+			for ti := 0; ti < maxKeyPos; ti++ {
+				kBase := ti*strideK + kHeadOffset
+				kPtr := &kData[kBase]
+				s := cpu.DotFloat32Ptr(qPtr, kPtr, headDimK) * float32(scale)
+				vBase := ti*strideV + vHeadOffset
+				vPtr := &vData[vBase]
+
+				if s > m {
+					alpha := expf(m - s)
+					if l != 0 {
+						for i := 0; i < headDimV; i++ {
+							outVec[i] *= alpha
+						}
+						l *= alpha
+					}
+					m = s
+					l += 1
+					cpu.AxpyPtr(1, vPtr, outPtr, headDimV)
+					continue
+				}
+
+				w := expf(s - m)
+				l += w
+				cpu.AxpyPtr(w, vPtr, outPtr, headDimV)
+			}
+			if l != 0 {
+				inv := 1 / l
+				for i := 0; i < headDimV; i++ {
+					outVec[i] *= inv
+				}
+			}
+		}
+	}
+}

+ 117 - 0
pkg/backend/cpu/nn/attention_cached_test.go

@@ -0,0 +1,117 @@
+package nn
+
+import (
+	"math"
+	"testing"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/kvcache"
+	"makarna/pkg/tensor"
+)
+
+func TestCausalAttentionBlocksMatchesContiguous(t *testing.T) {
+	numHeads, numKVHeads, headDim := 1, 1, 1
+	startPos := 1
+
+	// Two new tokens attending over three total tokens (one past + two new).
+	q := cpu.NewTensor(tensor.Shape{2, 1}, []float32{0.5, 1.0})
+	kAll := cpu.NewTensor(tensor.Shape{3, 1}, []float32{1, 2, 3})
+	vAll := cpu.NewTensor(tensor.Shape{3, 1}, []float32{10, 20, 30})
+
+	outContig := cpu.NewTensor(tensor.Shape{2, 1}, nil)
+	if err := CausalAttentionCached(q, kAll, vAll, outContig, numHeads, numKVHeads, headDim, startPos); err != nil {
+		t.Fatalf("contiguous attention failed: %v", err)
+	}
+
+	// Build block view equivalent to the contiguous tensors.
+	blockK := cpu.NewTensor(tensor.Shape{4, 1}, []float32{1, 2, 3, 0})
+	blockV := cpu.NewTensor(tensor.Shape{4, 1}, []float32{10, 20, 30, 0})
+	view := kvcache.View{
+		K:      blockK,
+		V:      blockV,
+		Start:  0,
+		Length: 3,
+		Device: tensor.CPU,
+	}
+
+	outBlocks := cpu.NewTensor(tensor.Shape{2, 1}, nil)
+	if err := CausalAttentionBlocks(q, []kvcache.View{view}, outBlocks, numHeads, numKVHeads, headDim, startPos); err != nil {
+		t.Fatalf("block attention failed: %v", err)
+	}
+
+	for i := range outContig.DataFloat32() {
+		if diff := math.Abs(float64(outContig.DataFloat32()[i] - outBlocks.DataFloat32()[i])); diff > 1e-5 {
+			t.Fatalf("mismatch at %d: contiguous=%v blocks=%v", i, outContig.DataFloat32()[i], outBlocks.DataFloat32()[i])
+		}
+	}
+}
+
+func TestCausalAttentionPackedMatchesBlocks(t *testing.T) {
+	numHeads, numKVHeads, headDim := 4, 2, 8
+	newTokens := 2
+	startPos := 4
+	blockSize := 8
+	kvDim := numKVHeads * headDim
+
+	qData := make([]float32, newTokens*numHeads*headDim)
+	for i := range qData {
+		qData[i] = float32(i%7) / 7
+	}
+	q := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, qData)
+
+	// total KV length includes past + current
+	total := startPos + newTokens
+	kData := make([]float32, total*kvDim)
+	vData := make([]float32, total*kvDim)
+	for i := range kData {
+		kData[i] = float32((i%17)-8) / 9
+		vData[i] = float32((i%19)-9) / 10
+	}
+
+	views := make([]kvcache.View, 0, (total+blockSize-1)/blockSize)
+	pviews := make([]kvcache.PackedView, 0, (total+blockSize-1)/blockSize)
+	for start := 0; start < total; start += blockSize {
+		length := blockSize
+		if start+length > total {
+			length = total - start
+		}
+		kBlkData := make([]float32, blockSize*kvDim)
+		vBlkData := make([]float32, blockSize*kvDim)
+		copy(kBlkData, kData[start*kvDim:(start+length)*kvDim])
+		copy(vBlkData, vData[start*kvDim:(start+length)*kvDim])
+		kBlk := cpu.NewTensor(tensor.Shape{blockSize, kvDim}, kBlkData)
+		vBlk := cpu.NewTensor(tensor.Shape{blockSize, kvDim}, vBlkData)
+		views = append(views, kvcache.View{K: kBlk, V: vBlk, Start: start, Length: length, Device: tensor.CPU})
+
+		pk := make([]float32, numKVHeads*blockSize*headDim)
+		pv := make([]float32, numKVHeads*blockSize*headDim)
+		for ti := 0; ti < length; ti++ {
+			baseTok := (start + ti) * kvDim
+			for h := 0; h < numKVHeads; h++ {
+				srcBase := baseTok + h*headDim
+				dstBase := h*(blockSize*headDim) + ti*headDim
+				copy(pk[dstBase:dstBase+headDim], kData[srcBase:srcBase+headDim])
+				copy(pv[dstBase:dstBase+headDim], vData[srcBase:srcBase+headDim])
+			}
+		}
+		pviews = append(pviews, kvcache.PackedView{K: pk, V: pv, Start: start, Length: length, BlockSize: blockSize, HeadDim: headDim, NumKVHeads: numKVHeads})
+	}
+
+	outBlocks := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, nil)
+	outPacked := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, nil)
+
+	if err := CausalAttentionBlocks(q, views, outBlocks, numHeads, numKVHeads, headDim, startPos); err != nil {
+		t.Fatalf("blocks attention failed: %v", err)
+	}
+	if err := CausalAttentionPackedBlocks(q, pviews, outPacked, numHeads, numKVHeads, headDim, startPos); err != nil {
+		t.Fatalf("packed attention failed: %v", err)
+	}
+
+	for i := range outBlocks.DataFloat32() {
+		diff := math.Abs(float64(outBlocks.DataFloat32()[i] - outPacked.DataFloat32()[i]))
+		if diff > 1e-5 {
+			t.Fatalf("mismatch at %d: blocks=%v packed=%v", i, outBlocks.DataFloat32()[i], outPacked.DataFloat32()[i])
+		}
+	}
+}
+

+ 128 - 0
pkg/backend/cpu/nn/conv1d.go

@@ -0,0 +1,128 @@
+package nn
+
+import (
+	"fmt"
+	"math"
+
+	"makarna/pkg/backend/cpu"
+)
+
+type ActivationKind uint8
+
+const (
+	ActivationNone ActivationKind = iota
+	ActivationSiLU
+	ActivationTanh
+	ActivationReLU
+)
+
+func applyActivation(x float32, act ActivationKind) float32 {
+	switch act {
+	case ActivationNone:
+		return x
+	case ActivationSiLU:
+		return x * Sigmoid(x)
+	case ActivationReLU:
+		if x < 0 {
+			return 0
+		}
+		return x
+	case ActivationTanh:
+		return float32(math.Tanh(float64(x)))
+	default:
+		return x
+	}
+}
+
+func FlattenConvWeights(w *cpu.Tensor, projSize int, kernel int) ([]float32, error) {
+	if w == nil {
+		return nil, fmt.Errorf("missing conv weights")
+	}
+	data := w.DataFloat32()
+	expected := projSize * kernel
+	shape := w.Shape()
+	if shape.NumElements() < expected || len(data) < expected {
+		return nil, fmt.Errorf("unexpected conv weight size %d", len(data))
+	}
+	if len(shape) == 2 {
+		if shape[0] == projSize && shape[1] == kernel {
+			return data[:expected], nil
+		}
+		if shape[0] == kernel && shape[1] == projSize {
+			out := make([]float32, expected)
+			for d := 0; d < projSize; d++ {
+				for j := 0; j < kernel; j++ {
+					out[d*kernel+j] = data[j*projSize+d]
+				}
+			}
+			return out, nil
+		}
+	}
+	if len(shape) == 3 {
+		if shape[0] == projSize && shape[1] == 1 && shape[2] == kernel {
+			return data[:expected], nil
+		}
+		if shape[0] == kernel && shape[1] == 1 && shape[2] == projSize {
+			out := make([]float32, expected)
+			for d := 0; d < projSize; d++ {
+				for j := 0; j < kernel; j++ {
+					out[d*kernel+j] = data[j*projSize+d]
+				}
+			}
+			return out, nil
+		}
+	}
+	// Unrecognized layout: fall back to assuming the data is already ordered as [projSize][kernel].
+	if len(data) >= expected {
+		return data[:expected], nil
+	}
+	return nil, fmt.Errorf("unexpected conv weight size %d", len(data))
+}
+
+func CausalShortConv1DInplaceAct(xFlat []float32, state *cpu.Tensor, w *cpu.Tensor, tokens int, projSize int, kernel int, act ActivationKind) error {
+	if kernel <= 1 {
+		for i := range xFlat {
+			xFlat[i] = applyActivation(xFlat[i], act)
+		}
+		return nil
+	}
+	convLen := kernel - 1
+	if state == nil {
+		return fmt.Errorf("nil conv state")
+	}
+	if state.Shape().NumElements() != projSize*convLen {
+		return fmt.Errorf("conv state shape mismatch %v", state.Shape())
+	}
+	weights, err := FlattenConvWeights(w, projSize, kernel)
+	if err != nil {
+		return err
+	}
+	st := state.DataFloat32()
+	out := make([]float32, len(xFlat))
+	for t := 0; t < tokens; t++ {
+		base := t * projSize
+		for d := 0; d < projSize; d++ {
+			acc := float32(0)
+			wBase := d * kernel
+			for j := 0; j < convLen; j++ {
+				acc += weights[wBase+j] * st[d*convLen+j]
+			}
+			acc += weights[wBase+convLen] * xFlat[base+d]
+			out[base+d] = applyActivation(acc, act)
+		}
+		if convLen > 0 {
+			for d := 0; d < projSize; d++ {
+				off := d * convLen
+				copy(st[off:off+convLen-1], st[off+1:off+convLen])
+				st[off+convLen-1] = xFlat[base+d]
+			}
+		}
+	}
+	copy(xFlat, out)
+	return nil
+}
+
+// CausalShortConv1DInplace is the backward-compatible API.
+// It applies a causal short conv1d followed by SiLU.
+func CausalShortConv1DInplace(xFlat []float32, state *cpu.Tensor, w *cpu.Tensor, tokens int, projSize int, kernel int) error {
+	return CausalShortConv1DInplaceAct(xFlat, state, w, tokens, projSize, kernel, ActivationSiLU)
+}
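
A single-channel scalar sketch of the causal short convolution above: the state holds the last kernel-1 inputs, each output mixes the state plus the current sample, and the state then shifts left by one. Kernel taps and inputs below are illustrative, and the activation is omitted.

package main

import "fmt"

func main() {
	// One channel, kernel size 4, so a rolling state of 3 past samples,
	// mirroring the inner loop of CausalShortConv1DInplaceAct.
	w := []float64{0.1, 0.2, 0.3, 0.4} // w[0..2] weight the state, w[3] the current input
	state := []float64{0, 0, 0}
	inputs := []float64{1, 2, 3, 4}

	for _, x := range inputs {
		acc := 0.0
		for j, s := range state {
			acc += w[j] * s
		}
		acc += w[len(w)-1] * x
		fmt.Println(acc)

		copy(state, state[1:]) // shift history left by one
		state[len(state)-1] = x
	}
}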

+ 229 - 0
pkg/backend/cpu/nn/embedding.go

@@ -0,0 +1,229 @@
+package nn
+
+import (
+	"fmt"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/tensor"
+)
+
+// Embedding looks up token embeddings
+// ids: token IDs
+// weight: [vocab_size, dim]
+// out: [seq_len, dim]
+func Embedding(ids []int, weight, out *cpu.Tensor) error {
+	inShape := weight.Shape()
+	if len(inShape) != 2 {
+		return fmt.Errorf("embedding: expected 2D weight, got %v", inShape)
+	}
+
+	vocabSize := inShape[0]
+	dim := inShape[1]
+
+	oData := out.DataFloat32()
+
+	// Validate output shape
+	outShape := out.Shape()
+	if outShape[0] != len(ids) || outShape[1] != dim {
+		return fmt.Errorf("embedding: output shape mismatch: expected [%d, %d], got %v", len(ids), dim, outShape)
+	}
+
+	switch weight.DType() {
+	case tensor.Float32:
+		wData := weight.DataFloat32()
+		for i, id := range ids {
+			if id < 0 || id >= vocabSize {
+				return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
+			}
+			src := wData[id*dim : (id+1)*dim]
+			dst := oData[i*dim : (i+1)*dim]
+			copy(dst, src)
+		}
+
+	case tensor.Float16:
+		wData := weight.DataUint16()
+		for i, id := range ids {
+			if id < 0 || id >= vocabSize {
+				return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
+			}
+			src := wData[id*dim : (id+1)*dim]
+			dst := oData[i*dim : (i+1)*dim]
+			for j := 0; j < dim; j++ {
+				dst[j] = float16BitsToFloat32(src[j])
+			}
+		}
+
+	case tensor.BFloat16:
+		wData := weight.DataUint16()
+		for i, id := range ids {
+			if id < 0 || id >= vocabSize {
+				return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
+			}
+			src := wData[id*dim : (id+1)*dim]
+			dst := oData[i*dim : (i+1)*dim]
+			for j := 0; j < dim; j++ {
+				dst[j] = bfloat16BitsToFloat32(src[j])
+			}
+		}
+
+	case tensor.Q4_K:
+		const blockSize = 256
+
+		if dim%blockSize != 0 {
+			return fmt.Errorf("embedding: Q4_K dim %d must be multiple of %d", dim, blockSize)
+		}
+
+		wData := weight.DataQ4_K()
+		blocksPerDim := dim / blockSize
+		var deqBuf [256]float32
+
+		for i, id := range ids {
+			if id < 0 || id >= vocabSize {
+				return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
+			}
+
+			dst := oData[i*dim : (i+1)*dim]
+			blockStart := id * blocksPerDim
+
+			for b := 0; b < blocksPerDim; b++ {
+				block := &wData[blockStart+b]
+				tensor.DequantizeQ4_K(block, deqBuf[:])
+				copy(dst[b*blockSize:], deqBuf[:])
+			}
+		}
+
+	case tensor.Q8_K:
+		const blockSize = 256
+
+		if dim%blockSize != 0 {
+			return fmt.Errorf("embedding: Q8_K dim %d must be multiple of %d", dim, blockSize)
+		}
+
+		wData := weight.DataQ8_K()
+		blocksPerDim := dim / blockSize
+		var deqBuf [256]float32
+
+		for i, id := range ids {
+			if id < 0 || id >= vocabSize {
+				return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
+			}
+
+			dst := oData[i*dim : (i+1)*dim]
+			blockStart := id * blocksPerDim
+
+			for b := 0; b < blocksPerDim; b++ {
+				block := &wData[blockStart+b]
+				tensor.DequantizeQ8_K(block, deqBuf[:])
+				copy(dst[b*blockSize:], deqBuf[:])
+			}
+		}
+
+	case tensor.Q3_K:
+		const blockSize = 256
+
+		if dim%blockSize != 0 {
+			return fmt.Errorf("embedding: Q3_K dim %d must be multiple of %d", dim, blockSize)
+		}
+
+		wData := weight.DataQ3_K()
+		blocksPerDim := dim / blockSize
+		var deqBuf [256]float32
+
+		for i, id := range ids {
+			if id < 0 || id >= vocabSize {
+				return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
+			}
+
+			dst := oData[i*dim : (i+1)*dim]
+			blockStart := id * blocksPerDim
+
+			for b := 0; b < blocksPerDim; b++ {
+				block := &wData[blockStart+b]
+				tensor.DequantizeQ3_K(block, deqBuf[:])
+				copy(dst[b*blockSize:], deqBuf[:])
+			}
+		}
+
+	case tensor.Q5_K:
+		const blockSize = 256
+
+		if dim%blockSize != 0 {
+			return fmt.Errorf("embedding: Q5_K dim %d must be multiple of %d", dim, blockSize)
+		}
+
+		wData := weight.DataQ5_K()
+		blocksPerDim := dim / blockSize
+		var deqBuf [256]float32
+
+		for i, id := range ids {
+			if id < 0 || id >= vocabSize {
+				return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
+			}
+
+			dst := oData[i*dim : (i+1)*dim]
+			blockStart := id * blocksPerDim
+
+			for b := 0; b < blocksPerDim; b++ {
+				block := &wData[blockStart+b]
+				tensor.DequantizeQ5_K(block, deqBuf[:])
+				copy(dst[b*blockSize:], deqBuf[:])
+			}
+		}
+
+	case tensor.Q6_K:
+		const blockSize = 256
+
+		if dim%blockSize != 0 {
+			return fmt.Errorf("embedding: Q6_K dim %d must be multiple of %d", dim, blockSize)
+		}
+
+		wData := weight.DataQ6_K()
+		blocksPerDim := dim / blockSize
+
+		for i, id := range ids {
+			if id < 0 || id >= vocabSize {
+				return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
+			}
+
+			dst := oData[i*dim : (i+1)*dim]
+			blockStart := id * blocksPerDim
+
+			for b := 0; b < blocksPerDim; b++ {
+				block := &wData[blockStart+b]
+				seg := dst[b*blockSize : (b+1)*blockSize]
+				tensor.DequantizeQ6_K(block, seg)
+			}
+		}
+
+	case tensor.Q2_K:
+		const blockSize = 256
+
+		if dim%blockSize != 0 {
+			return fmt.Errorf("embedding: Q2_K dim %d must be multiple of %d", dim, blockSize)
+		}
+
+		wData := weight.DataQ2_K()
+		blocksPerDim := dim / blockSize
+		var deqBuf [256]float32
+
+		for i, id := range ids {
+			if id < 0 || id >= vocabSize {
+				return fmt.Errorf("embedding: id %d out of range [0, %d)", id, vocabSize)
+			}
+
+			dst := oData[i*dim : (i+1)*dim]
+			blockStart := id * blocksPerDim
+
+			for b := 0; b < blocksPerDim; b++ {
+				block := &wData[blockStart+b]
+				tensor.DequantizeQ2_K(block, deqBuf[:])
+				copy(dst[b*blockSize:], deqBuf[:])
+			}
+		}
+
+	default:
+		return fmt.Errorf("embedding: unsupported weight dtype %v", weight.DType())
+	}
+
+	return nil
+}
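
A small usage sketch of the float32 path, assuming the constructors shown elsewhere in this commit (cpu.NewTensor, tensor.Shape, DataFloat32); quantized weights go through the same call, only the per-block dequantization differs. The vocabulary and values are illustrative.

package main

import (
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/tensor"
)

func main() {
	// Toy vocabulary of 3 tokens with dim 2; row i of weight is the embedding of token i.
	weight := cpu.NewTensor(tensor.Shape{3, 2}, []float32{
		0.1, 0.2, // token 0
		0.3, 0.4, // token 1
		0.5, 0.6, // token 2
	})
	ids := []int{2, 0}
	out := cpu.NewTensor(tensor.Shape{len(ids), 2}, nil)

	if err := nn.Embedding(ids, weight, out); err != nil {
		panic(err)
	}
	fmt.Println(out.DataFloat32()) // [0.5 0.6 0.1 0.2]
}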

+ 42 - 0
pkg/backend/cpu/nn/input.go

@@ -0,0 +1,42 @@
+package nn
+
+import (
+	"unsafe"
+
+	"makarna/pkg/tensor"
+)
+
+// ParseTokenIDs extracts integer token IDs from a float32 tensor
+// Input tensor shape: [seqLen]
+func ParseTokenIDs(input tensor.Tensor) []int {
+	seqLen := input.Shape()[0]
+	ptr := input.Data().(unsafe.Pointer)
+	slice := unsafe.Slice((*float32)(ptr), seqLen)
+
+	ids := make([]int, seqLen)
+	for i, v := range slice {
+		ids[i] = int(v)
+	}
+	return ids
+}
+
+// ParsePositions extracts position indices from tensor or generates default [0,1,2,...]
+// If positions is nil, generates sequential positions starting from 0
+func ParsePositions(positions tensor.Tensor, seqLen int) []int {
+	if positions == nil {
+		arr := make([]int, seqLen)
+		for i := range arr {
+			arr[i] = i
+		}
+		return arr
+	}
+
+	ptr := positions.Data().(unsafe.Pointer)
+	slice := unsafe.Slice((*float32)(ptr), seqLen)
+
+	arr := make([]int, seqLen)
+	for i, v := range slice {
+		arr[i] = int(v)
+	}
+	return arr
+}

+ 114 - 0
pkg/backend/cpu/nn/kda.go

@@ -0,0 +1,114 @@
+package nn
+
+import (
+	"fmt"
+	"math"
+
+	"makarna/pkg/backend/cpu"
+)
+
+func KDAGate(gFlat []float32, aLog []float32, headDim int, dtBias []float32) []float32 {
+	h := len(aLog)
+	if h*headDim != len(gFlat) {
+		return nil
+	}
+	out := make([]float32, len(gFlat))
+	for hi := 0; hi < h; hi++ {
+		mul := -float32(math.Exp(float64(aLog[hi])))
+		base := hi * headDim
+		for d := 0; d < headDim; d++ {
+			x := gFlat[base+d]
+			if dtBias != nil {
+				x += dtBias[base+d]
+			}
+			out[base+d] = mul * Softplus(x)
+		}
+	}
+	return out
+}
+
+func KDARecurrent(qFlat, kFlat, vFlat, gFlat, beta []float32, state []float32, tokens, numHeads, headDim int) error {
+	stride := numHeads * headDim
+	strideB := numHeads
+	stateStride := headDim * headDim
+	if len(qFlat) < tokens*stride || len(kFlat) < tokens*stride || len(vFlat) < tokens*stride || len(gFlat) < tokens*stride {
+		return fmt.Errorf("KDARecurrent: input size mismatch")
+	}
+	if len(beta) < tokens*strideB {
+		return fmt.Errorf("KDARecurrent: beta size mismatch")
+	}
+	if state == nil || len(state) != numHeads*stateStride {
+		return fmt.Errorf("KDARecurrent: state size mismatch")
+	}
+	scale := float32(1.0 / math.Sqrt(float64(headDim)))
+
+	tmpKV := make([]float32, headDim)
+	tmpVM := make([]float32, headDim)
+
+	for t := 0; t < tokens; t++ {
+		for h := 0; h < numHeads; h++ {
+			off := t*stride + h*headDim
+			b := beta[t*strideB+h]
+			SOff := h * stateStride
+
+			for kk := 0; kk < headDim; kk++ {
+				dec := float32(math.Exp(float64(gFlat[off+kk])))
+				rowBase := SOff + kk*headDim
+				for vv := 0; vv < headDim; vv++ {
+					state[rowBase+vv] *= dec
+				}
+			}
+
+			for vv := 0; vv < headDim; vv++ {
+				acc := float32(0)
+				for kk := 0; kk < headDim; kk++ {
+					acc += kFlat[off+kk] * state[SOff+kk*headDim+vv]
+				}
+				tmpKV[vv] = acc
+			}
+			for vv := 0; vv < headDim; vv++ {
+				tmpVM[vv] = vFlat[off+vv] - tmpKV[vv]
+			}
+			for kk := 0; kk < headDim; kk++ {
+				kj := b * kFlat[off+kk]
+				row := state[SOff+kk*headDim : SOff+(kk+1)*headDim]
+				cpu.Axpy(kj, tmpVM, row)
+			}
+
+			for vv := 0; vv < headDim; vv++ {
+				acc := float32(0)
+				for kk := 0; kk < headDim; kk++ {
+					acc += (qFlat[off+kk] * scale) * state[SOff+kk*headDim+vv]
+				}
+				vFlat[off+vv] = acc
+			}
+		}
+	}
+	return nil
+}
+
+func RMSNormGated(out []float32, g []float32, weight []float32, headDim int, eps float32) {
+	if weight == nil {
+		return
+	}
+	for i := 0; i < len(out); i += headDim {
+		ss := float32(0)
+		for j := 0; j < headDim; j++ {
+			v := out[i+j]
+			ss += v * v
+		}
+		inv := float32(1.0 / math.Sqrt(float64(ss/float32(headDim)+eps)))
+		for j := 0; j < headDim; j++ {
+			y := out[i+j] * inv * weight[j]
+			if g != nil {
+				y *= Sigmoid(g[i+j])
+			}
+			out[i+j] = y
+		}
+	}
+}
+
+// FlattenALog is a thin convenience wrapper around FlattenVector.
+func FlattenALog(t *cpu.Tensor, numHeads int) ([]float32, error) {
+	return FlattenVector(t, numHeads, "A_log")
+}
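
A scalar sketch of the gate computed by KDAGate, g = -exp(A_log) * softplus(x + dt_bias): g is always non-positive, so exp(g) inside KDARecurrent acts as a per-channel decay in (0, 1]. The values are illustrative and softplus is assumed to be the standard log(1+exp(x)) form.

package main

import (
	"fmt"
	"math"
)

// softplus is the standard smooth ReLU; assumed equivalent to the package's Softplus.
func softplus(x float64) float64 { return math.Log1p(math.Exp(x)) }

func main() {
	aLog, x, dtBias := 0.5, -1.0, 0.25 // illustrative values
	g := -math.Exp(aLog) * softplus(x+dtBias)
	decay := math.Exp(g) // what KDARecurrent applies to each state row
	fmt.Printf("gate=%.4f decay=%.4f\n", g, decay)
}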

+ 32 - 0
pkg/backend/cpu/nn/mlp.go

@@ -0,0 +1,32 @@
+package nn
+
+import (
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/backend/cpu/matmul"
+	"makarna/pkg/backend/cpu/ops"
+	"makarna/pkg/tensor"
+)
+
+// SwiGLUMLP computes the SwiGLU MLP block:
+// gate = x @ wGate
+// up = x @ wUp
+// hidden = SwiGLU(gate, up)
+// out = hidden @ wDown
+func SwiGLUMLP(x, wGate, wUp, wDown *cpu.Tensor) *cpu.Tensor {
+	seqLen := x.Shape()[0]
+	intermediate := wGate.Shape()[0]
+	hiddenSize := wDown.Shape()[0]
+
+	gate := ops.Zeros(tensor.Shape{seqLen, intermediate})
+	up := ops.Zeros(tensor.Shape{seqLen, intermediate})
+
+	matmul.Linear(x, wGate, gate)
+	matmul.Linear(x, wUp, up)
+
+	SwiGLU(gate, up, gate)
+
+	out := ops.Zeros(tensor.Shape{seqLen, hiddenSize})
+	matmul.Linear(gate, wDown, out)
+
+	return out
+}

+ 181 - 0
pkg/backend/cpu/nn/moe.go

@@ -0,0 +1,181 @@
+package nn
+
+import (
+	"sort"
+)
+
+// MoEChoice represents a selected expert with its weight.
+type MoEChoice struct {
+	Idx    int
+	Weight float32
+}
+
+// TopKIndices returns the indices of the top-k largest values in scores.
+func TopKIndices(scores []float32, k int) []int {
+	if k <= 0 {
+		k = 1
+	}
+	if k > len(scores) {
+		k = len(scores)
+	}
+	idx := make([]int, len(scores))
+	for i := range idx {
+		idx[i] = i
+	}
+	sort.Slice(idx, func(i, j int) bool { return scores[idx[i]] > scores[idx[j]] })
+	return idx[:k]
+}
+
+// SelectTopKExperts selects top-k experts from scores and returns their indices and weights.
+// If useOriginalWeights is provided, weights are taken from that slice instead of scores.
+func SelectTopKExperts(scores []float32, k int, useOriginalWeights []float32) []MoEChoice {
+	if k <= 0 {
+		k = 1
+	}
+	if k > len(scores) {
+		k = len(scores)
+	}
+
+	type scored struct {
+		idx   int
+		score float32
+	}
+	choices := make([]scored, len(scores))
+	for i, s := range scores {
+		choices[i] = scored{idx: i, score: s}
+	}
+	sort.Slice(choices, func(i, j int) bool { return choices[i].score > choices[j].score })
+
+	result := make([]MoEChoice, k)
+	for i := 0; i < k; i++ {
+		idx := choices[i].idx
+		w := scores[idx]
+		if useOriginalWeights != nil {
+			w = useOriginalWeights[idx]
+		}
+		result[i] = MoEChoice{Idx: idx, Weight: w}
+	}
+	return result
+}
+
+// GroupedTopKMask applies grouped top-k selection as in DeepSeek-V3/Kimi MoE.
+// Returns a masked score array where only experts from selected groups are non-zero.
+//
+// Parameters:
+//   - scores: router scores with bias already added (for selection)
+//   - numGroups: number of expert groups
+//   - topKGroup: how many groups to keep
+//
+// Returns masked scores suitable for final top-k selection.
+func GroupedTopKMask(scores []float32, numGroups, topKGroup int) []float32 {
+	if numGroups <= 0 {
+		numGroups = 1
+	}
+	numExperts := len(scores)
+	if numExperts%numGroups != 0 {
+		return scores // fallback: no masking
+	}
+
+	perGroup := numExperts / numGroups
+
+	// Compute group scores (sum of top-2 in each group)
+	groupScores := make([]float32, numGroups)
+	for gi := 0; gi < numGroups; gi++ {
+		base := gi * perGroup
+		seg := scores[base : base+perGroup]
+		top2Idx := TopKIndices(seg, 2)
+		s := float32(0)
+		for _, id := range top2Idx {
+			s += seg[id]
+		}
+		groupScores[gi] = s
+	}
+
+	// Select top-k groups
+	keepGroups := TopKIndices(groupScores, topKGroup)
+	keep := make([]bool, numGroups)
+	for _, gi := range keepGroups {
+		keep[gi] = true
+	}
+
+	// Create masked output
+	masked := make([]float32, numExperts)
+	for gi := 0; gi < numGroups; gi++ {
+		base := gi * perGroup
+		if keep[gi] {
+			copy(masked[base:base+perGroup], scores[base:base+perGroup])
+		}
+		// else: zeros (already initialized)
+	}
+	return masked
+}
+
+// RenormalizeMoEWeights normalizes weights to sum to 1.
+func RenormalizeMoEWeights(choices []MoEChoice) {
+	if len(choices) == 0 {
+		return
+	}
+	sum := float32(0)
+	for _, c := range choices {
+		sum += c.Weight
+	}
+	if sum == 0 {
+		return
+	}
+	inv := 1 / sum
+	for i := range choices {
+		choices[i].Weight *= inv
+	}
+}
+
+// ScaleMoEWeights multiplies all weights by the given factor.
+func ScaleMoEWeights(choices []MoEChoice, factor float32) {
+	for i := range choices {
+		choices[i].Weight *= factor
+	}
+}
+
+// MoERouterActivation applies activation function to router logits.
+func MoERouterActivation(logits []float32, activationFunc string) []float32 {
+	scores := make([]float32, len(logits))
+	switch activationFunc {
+	case "sigmoid":
+		for i, v := range logits {
+			scores[i] = Sigmoid(v)
+		}
+	case "softmax":
+		copy(scores, logits)
+		SoftmaxInplaceSimple(scores)
+	default:
+		// Default to sigmoid
+		for i, v := range logits {
+			scores[i] = Sigmoid(v)
+		}
+	}
+	return scores
+}
+
+// SoftmaxInplaceSimple applies softmax normalization in-place (simple scalar version).
+func SoftmaxInplaceSimple(data []float32) {
+	if len(data) == 0 {
+		return
+	}
+	maxVal := data[0]
+	for _, v := range data[1:] {
+		if v > maxVal {
+			maxVal = v
+		}
+	}
+	sum := float32(0)
+	for i := range data {
+		data[i] = Exp(data[i] - maxVal)
+		sum += data[i]
+	}
+	if sum == 0 {
+		return
+	}
+	inv := 1 / sum
+	for i := range data {
+		data[i] *= inv
+	}
+}
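
A small end-to-end sketch of how these helpers compose for one router step (sigmoid activation, grouped masking, final top-k, renormalize). Expert count, group count, and k are illustrative; the router bias handling described above is omitted.

package main

import (
	"fmt"

	"makarna/pkg/backend/cpu/nn"
)

func main() {
	// 8 experts in 4 groups, keep the best 2 groups, then route to 2 experts.
	logits := []float32{0.1, -0.3, 1.2, 0.4, -0.9, 0.8, 0.2, -0.1}

	scores := nn.MoERouterActivation(logits, "sigmoid")
	masked := nn.GroupedTopKMask(scores, 4, 2)
	choices := nn.SelectTopKExperts(masked, 2, scores) // weights taken from the unmasked scores
	nn.RenormalizeMoEWeights(choices)

	for _, c := range choices {
		fmt.Printf("expert %d weight %.3f\n", c.Idx, c.Weight)
	}
}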

+ 168 - 0
pkg/backend/cpu/nn/nn_bench_test.go

@@ -0,0 +1,168 @@
+package nn
+
+import (
+	"math/rand"
+	"strconv"
+	"testing"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/kvcache"
+	"makarna/pkg/tensor"
+)
+
+func BenchmarkRMSNorm(b *testing.B) {
+	dim := 256
+	rows := 16
+	data := make([]float32, rows*dim)
+	for i := range data {
+		data[i] = rand.Float32()
+	}
+	w := make([]float32, dim)
+	for i := range w {
+		w[i] = 1
+	}
+	x := cpu.NewTensor(tensor.Shape{rows, dim}, data)
+	ws := cpu.NewTensor(tensor.Shape{dim}, w)
+
+	b.ReportAllocs()
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		RMSNorm(x, ws, 1e-5)
+	}
+}
+
+func BenchmarkCausalAttentionPackedVsBlocks(b *testing.B) {
+	// Decode-like: newTokens=1, numHeads=16, numKVHeads=8, headDim=128.
+	numHeads, numKVHeads, headDim := 16, 8, 128
+	newTokens := 1
+	kvDim := numKVHeads * headDim
+	blockSize := 16
+
+	x := make([]float32, newTokens*numHeads*headDim)
+	for i := range x {
+		x[i] = rand.Float32() - 0.5
+	}
+	q := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, x)
+	out := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, nil)
+
+	for _, kvLen := range []int{256, 1024, 4096, 16384} {
+		b.Run("kvLen="+strconv.Itoa(kvLen), func(b *testing.B) {
+			kData := make([]float32, kvLen*kvDim)
+			vData := make([]float32, kvLen*kvDim)
+			for i := range kData {
+				kData[i] = rand.Float32() - 0.5
+			}
+			for i := range vData {
+				vData[i] = rand.Float32() - 0.5
+			}
+
+			// Build both view types over the same data.
+			views := make([]kvcache.View, 0, (kvLen+blockSize-1)/blockSize)
+			pviews := make([]kvcache.PackedView, 0, (kvLen+blockSize-1)/blockSize)
+			for start := 0; start < kvLen; start += blockSize {
+				length := blockSize
+				if start+length > kvLen {
+					length = kvLen - start
+				}
+				// token-major blocks
+				kBlk := cpu.NewTensor(tensor.Shape{blockSize, kvDim}, kData[start*kvDim:(start+blockSize)*kvDim])
+				vBlk := cpu.NewTensor(tensor.Shape{blockSize, kvDim}, vData[start*kvDim:(start+blockSize)*kvDim])
+				views = append(views, kvcache.View{K: kBlk, V: vBlk, Start: start, Length: length, Device: tensor.CPU})
+
+				// packed blocks
+				pk := make([]float32, numKVHeads*blockSize*headDim)
+				pv := make([]float32, numKVHeads*blockSize*headDim)
+				// repack the same block into head-major [kvHead][token][headDim] layout
+				for t := 0; t < length; t++ {
+					baseTok := (start + t) * kvDim
+					for h := 0; h < numKVHeads; h++ {
+						srcBase := baseTok + h*headDim
+						dstBase := h*(blockSize*headDim) + t*headDim
+						copy(pk[dstBase:dstBase+headDim], kData[srcBase:srcBase+headDim])
+						copy(pv[dstBase:dstBase+headDim], vData[srcBase:srcBase+headDim])
+					}
+				}
+				pviews = append(pviews, kvcache.PackedView{K: pk, V: pv, Start: start, Length: length, BlockSize: blockSize, HeadDim: headDim, NumKVHeads: numKVHeads})
+			}
+
+			for _, fast := range []bool{false, true} {
+				name := "exp=exact"
+				if fast {
+					name = "exp=fast"
+				}
+				b.Run(name, func(b *testing.B) {
+					orig := useFastExp
+					useFastExp = fast
+					defer func() { useFastExp = orig }()
+
+					b.Run("blocks", func(b *testing.B) {
+						b.ReportAllocs()
+						b.ResetTimer()
+						for i := 0; i < b.N; i++ {
+							_ = CausalAttentionBlocks(q, views, out, numHeads, numKVHeads, headDim, kvLen-1)
+						}
+					})
+					b.Run("packed", func(b *testing.B) {
+						b.ReportAllocs()
+						b.ResetTimer()
+						for i := 0; i < b.N; i++ {
+							_ = CausalAttentionPackedBlocks(q, pviews, out, numHeads, numKVHeads, headDim, kvLen-1)
+						}
+					})
+				})
+			}
+		})
+	}
+}
+
+func BenchmarkSoftmax(b *testing.B) {
+	n := 512
+	data := make([]float32, n)
+	for i := range data {
+		data[i] = rand.Float32()
+	}
+	x := cpu.NewTensor(tensor.Shape{n}, data)
+
+	b.ReportAllocs()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		Softmax(x)
+	}
+}
+
+func BenchmarkCausalAttentionCached_Fused(b *testing.B) {
+	// Decode-like: newTokens=1, numHeads=16, numKVHeads=8, headDim=128.
+	numHeads, numKVHeads, headDim := 16, 8, 128
+	newTokens := 1
+	kvDim := numKVHeads * headDim
+
+	x := make([]float32, newTokens*numHeads*headDim)
+	for i := range x {
+		x[i] = rand.Float32() - 0.5
+	}
+	q := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, x)
+
+	for _, kvLen := range []int{256, 1024, 4096, 16384} {
+		b.Run("kvLen="+strconv.Itoa(kvLen), func(b *testing.B) {
+			kData := make([]float32, kvLen*kvDim)
+			vData := make([]float32, kvLen*kvDim)
+			for i := range kData {
+				kData[i] = rand.Float32() - 0.5
+			}
+			for i := range vData {
+				vData[i] = rand.Float32() - 0.5
+			}
+			k := cpu.NewTensor(tensor.Shape{kvLen, kvDim}, kData)
+			v := cpu.NewTensor(tensor.Shape{kvLen, kvDim}, vData)
+			out := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, nil)
+
+			b.ReportAllocs()
+			b.ResetTimer()
+			for i := 0; i < b.N; i++ {
+				_ = CausalAttentionCached(q, k, v, out, numHeads, numKVHeads, headDim, kvLen-1)
+			}
+		})
+	}
+}
+

+ 84 - 0
pkg/backend/cpu/nn/nn_simd_test.go

@@ -0,0 +1,84 @@
+package nn
+
+import (
+	"math"
+	"math/rand"
+	"testing"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/tensor"
+)
+
+func TestRMSNormMatchesReference(t *testing.T) {
+	dim := 8
+	x := cpu.NewTensor(tensor.Shape{1, dim}, []float32{1, 2, 3, 4, 5, 6, 7, 8})
+	w := cpu.NewTensor(tensor.Shape{dim}, make([]float32, dim))
+	for i := range w.DataFloat32() {
+		w.DataFloat32()[i] = 1
+	}
+
+	if err := RMSNorm(x, w, 1e-5); err != nil {
+		t.Fatalf("rmsnorm err: %v", err)
+	}
+
+	// Reference
+	row := x.DataFloat32()
+	var ss float32
+	for _, v := range []float32{1, 2, 3, 4, 5, 6, 7, 8} {
+		ss += v * v
+	}
+	ss /= float32(dim)
+	inv := 1 / float32(math.Sqrt(float64(ss+1e-5)))
+
+	for i, v := range []float32{1, 2, 3, 4, 5, 6, 7, 8} {
+		want := v * inv
+		if diff := absDiff(row[i], want); diff > 1e-4 {
+			t.Fatalf("rmsnorm mismatch at %d: got %f want %f", i, row[i], want)
+		}
+	}
+}
+
+func TestSoftmaxSumsToOne(t *testing.T) {
+	data := []float32{0.1, 1.2, -0.3, 0.4}
+	x := cpu.NewTensor(tensor.Shape{len(data)}, append([]float32(nil), data...))
+	if err := Softmax(x); err != nil {
+		t.Fatalf("softmax err: %v", err)
+	}
+	sum := float32(0)
+	for _, v := range x.DataFloat32() {
+		sum += v
+		if v <= 0 {
+			t.Fatalf("softmax produced non-positive prob %f", v)
+		}
+	}
+	if diff := absDiff(sum, 1); diff > 1e-5 {
+		t.Fatalf("softmax sum != 1: got %f", sum)
+	}
+}
+
+func TestRoPENoNaN(t *testing.T) {
+	headDim := 4
+	seq := 3
+	data := make([]float32, seq*headDim)
+	for i := range data {
+		data[i] = rand.Float32()
+	}
+	x := cpu.NewTensor(tensor.Shape{seq, headDim}, data)
+	positions := []int{0, 1, 2}
+	if err := RoPE(x, positions, headDim, 10000); err != nil {
+		t.Fatalf("rope err: %v", err)
+	}
+	for i, v := range x.DataFloat32() {
+		if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
+			t.Fatalf("rope produced invalid value at %d: %f", i, v)
+		}
+	}
+}
+
+func absDiff(a, b float32) float32 {
+	if a > b {
+		return a - b
+	}
+	return b - a
+}
+

+ 32 - 0
pkg/backend/cpu/nn/rmsnorm.go

@@ -0,0 +1,32 @@
+// Package nn provides neural network layer operations
+package nn
+
+import (
+	"math"
+
+	"makarna/pkg/backend/cpu"
+)
+
+// RMSNorm normalizes input in-place using RMS normalization
+// x: [batch, dim], w: [dim]
+func RMSNorm(x, w *cpu.Tensor, eps float32) error {
+	xData := x.DataFloat32()
+	wData := w.DataFloat32()
+
+	dim := w.Shape().NumElements()
+	numRows := x.Shape().NumElements() / dim
+
+	for i := 0; i < numRows; i++ {
+		row := xData[i*dim : (i+1)*dim]
+
+		// Sum of squares (uses SIMD dot when available)
+		ss := cpu.DotFloat32(row, row) / float32(dim)
+
+		// Normalize and scale
+		invRMS := 1.0 / float32(math.Sqrt(float64(ss+eps)))
+		for j := 0; j < dim; j++ {
+			row[j] = row[j] * invRMS * wData[j]
+		}
+	}
+	return nil
+}

+ 69 - 0
pkg/backend/cpu/nn/rope.go

@@ -0,0 +1,69 @@
+package nn
+
+import (
+	"math"
+
+	"makarna/pkg/backend/cpu"
+)
+
+// RoPE applies Rotary Positional Embeddings in-place
+// x: [seqLen, numHeads * headDim]
+// positions: position for each token in sequence (len = seqLen)
+// headDim: dimension of each attention head
+// theta: RoPE base frequency (typically 10000 for Llama, 1000000 for Qwen3)
+//
+// This uses the split-half rotation format (HuggingFace standard):
+// - Split head into [first_half, second_half]
+// - new_first = first * cos - second * sin
+// - new_second = second * cos + first * sin
+func RoPE(x *cpu.Tensor, positions []int, headDim int, theta float32) error {
+	data := x.DataFloat32()
+	shape := x.Shape()
+	seqLen := shape[0]
+	totalDim := shape[1] // numHeads * headDim
+	halfDim := headDim / 2
+
+	// Precompute inverse frequencies for the half-dimension once per call
+	invFreqs := make([]float64, halfDim)
+	for j := 0; j < halfDim; j++ {
+		invFreqs[j] = 1.0 / math.Pow(float64(theta), float64(2*j)/float64(headDim))
+	}
+
+	for seq := 0; seq < seqLen; seq++ {
+		pos := positions[seq]
+		rowStart := seq * totalDim
+
+		// Apply RoPE to each head
+		for headStart := 0; headStart < totalDim; headStart += headDim {
+			for j := 0; j < halfDim; j++ {
+				// Compute frequency: precomputed invFreq * position
+				freq := float64(pos) * invFreqs[j]
+				sin, cos := math.Sincos(freq)
+
+				// Split-half indexing: pair (j, j + halfDim)
+				idx0 := rowStart + headStart + j           // First half element
+				idx1 := rowStart + headStart + j + halfDim // Second half element
+
+				v0 := data[idx0] // first half value
+				v1 := data[idx1] // second half value
+
+				// Rotation:
+				// new_first = first * cos - second * sin
+				// new_second = second * cos + first * sin
+				data[idx0] = v0*float32(cos) - v1*float32(sin)
+				data[idx1] = v1*float32(cos) + v0*float32(sin)
+			}
+		}
+	}
+	return nil
+}
+
+// RoPESingle applies RoPE with sequential positions pos, pos+1, ... (convenience for single-token generation)
+func RoPESingle(x *cpu.Tensor, pos, headDim int, theta float32) error {
+	seqLen := x.Shape()[0]
+	positions := make([]int, seqLen)
+	for i := range positions {
+		positions[i] = pos + i
+	}
+	return RoPE(x, positions, headDim, theta)
+}
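
A scalar check of the split-half rotation: each pair (j, j+halfDim) undergoes a 2-D rotation by angle pos / theta^(2j/headDim), so the pair's norm is unchanged. The head dim, theta, position, and input values below are illustrative.

package main

import (
	"fmt"
	"math"
)

func main() {
	// One head of dim 4 => pairs (0,2) and (1,3), matching the split-half indexing in RoPE.
	headDim, theta, pos := 4, 10000.0, 7
	x := []float64{0.5, -1.0, 2.0, 0.25}
	half := headDim / 2

	for j := 0; j < half; j++ {
		invFreq := 1.0 / math.Pow(theta, float64(2*j)/float64(headDim))
		sin, cos := math.Sincos(float64(pos) * invFreq)
		v0, v1 := x[j], x[j+half]
		r0 := v0*cos - v1*sin
		r1 := v1*cos + v0*sin
		// The rotation preserves the pair's norm.
		fmt.Printf("pair %d: |before|=%.6f |after|=%.6f\n",
			j, math.Hypot(v0, v1), math.Hypot(r0, r1))
	}
}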

+ 97 - 0
pkg/backend/cpu/nn/self_attention.go

@@ -0,0 +1,97 @@
+package nn
+
+import (
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/backend/cpu/matmul"
+	"makarna/pkg/backend/cpu/ops"
+	"makarna/pkg/kvcache"
+	"makarna/pkg/tensor"
+)
+
+// SelfAttentionConfig holds parameters for self-attention computation
+type SelfAttentionConfig struct {
+	HeadDim    int
+	NumHeads   int
+	NumKVHeads int
+	RopeTheta  float32
+	RMSNormEps float32
+}
+
+// SelfAttention computes full self-attention with optional QK norm and KV cache
+// x: input tensor [seqLen, hiddenSize]
+// wq, wk, wv, wo: projection weights
+// qNorm, kNorm: optional QK normalization weights (can be nil)
+// positions: position indices for RoPE
+// cache: optional KV cache (can be nil)
+// layerIdx: layer index for cache
+func SelfAttention(
+	x *cpu.Tensor,
+	wq, wk, wv, wo *cpu.Tensor,
+	qNorm, kNorm *cpu.Tensor,
+	positions []int,
+	cfg SelfAttentionConfig,
+	cache kvcache.KVCacheInterface,
+	layerIdx int,
+) *cpu.Tensor {
+	seqLen := x.Shape()[0]
+	hiddenSize := wo.Shape()[0]
+
+	wqShape := wq.Shape()
+	wkShape := wk.Shape()
+	wvShape := wv.Shape()
+
+	// Q, K, V projections
+	xq := ops.Zeros(tensor.Shape{seqLen, wqShape[0]})
+	xk := ops.Zeros(tensor.Shape{seqLen, wkShape[0]})
+	xv := ops.Zeros(tensor.Shape{seqLen, wvShape[0]})
+
+	matmul.Linear(x, wq, xq)
+	matmul.Linear(x, wk, xk)
+	matmul.Linear(x, wv, xv)
+
+	// Apply QK norm if available (Qwen3)
+	if qNorm != nil {
+		RMSNorm(xq, qNorm, cfg.RMSNormEps)
+	}
+	if kNorm != nil {
+		RMSNorm(xk, kNorm, cfg.RMSNormEps)
+	}
+
+	// RoPE
+	RoPE(xq, positions, cfg.HeadDim, cfg.RopeTheta)
+	RoPE(xk, positions, cfg.HeadDim, cfg.RopeTheta)
+
+	numQHeads := wqShape[0] / cfg.HeadDim
+	numKVHeads := wkShape[0] / cfg.HeadDim
+	attnOut := ops.Zeros(tensor.Shape{seqLen, wqShape[0]})
+
+	// KV Cache Logic
+	if cache != nil {
+		views, startPos, err := cache.Append(layerIdx, xk, xv)
+		if err != nil {
+			panic(err)
+		}
+		if pv, ok := cache.(kvcache.PackedViewsProvider); ok {
+			pviews := pv.ViewsPacked(layerIdx)
+			if len(pviews) > 0 {
+				if err := CausalAttentionPackedBlocks(xq, pviews, attnOut, numQHeads, numKVHeads, cfg.HeadDim, startPos); err != nil {
+					panic(err)
+				}
+				goto done
+			}
+		}
+		if err := CausalAttentionBlocks(xq, views, attnOut, numQHeads, numKVHeads, cfg.HeadDim, startPos); err != nil {
+			panic(err)
+		}
+	} else {
+		CausalAttention(xq, xk, xv, attnOut, numQHeads, numKVHeads, cfg.HeadDim)
+	}
+
+done:
+
+	// Output projection
+	out := ops.Zeros(tensor.Shape{seqLen, hiddenSize})
+	matmul.Linear(attnOut, wo, out)
+
+	return out
+}

+ 72 - 0
pkg/backend/cpu/nn/silu.go

@@ -0,0 +1,72 @@
+package nn
+
+import (
+	"math"
+
+	"makarna/pkg/backend/cpu"
+)
+
+// SiLU applies x * sigmoid(x) in-place using the fastest available kernel.
+func SiLU(x *cpu.Tensor) error {
+	siluInplace(x.DataFloat32())
+	return nil
+}
+
+// SwiGLU computes out = SiLU(gate) * up. gate is left unmodified unless out
+// aliases it, in which case the activation is applied in place.
+func SwiGLU(gate, up, out *cpu.Tensor) error {
+	gData := gate.DataFloat32()
+	uData := up.DataFloat32()
+	oData := out.DataFloat32()
+
+	if len(oData) == 0 {
+		return nil
+	}
+
+	if &gData[0] != &oData[0] {
+		copy(oData, gData)
+	}
+
+	siluInplace(oData)
+
+	for i := range oData {
+		oData[i] *= uData[i]
+	}
+	return nil
+}
+
+// siluInplace selects the SIMD kernel when available, falling back to scalar.
+func siluInplace(data []float32) {
+	if len(data) == 0 {
+		return
+	}
+
+	switch {
+	case hasSiLUAVX512 && cpu.SupportsAVX512():
+		main := len(data) &^ 15
+		if main > 0 {
+			siluAVX512Asm(&data[0], main)
+		}
+		if main == len(data) {
+			return
+		}
+		data = data[main:]
+	case hasSiLUAVX2 && cpu.SupportsAVX2():
+		main := len(data) &^ 7
+		if main > 0 {
+			siluAVX2Asm(&data[0], main)
+		}
+		if main == len(data) {
+			return
+		}
+		data = data[main:]
+	}
+
+	siluScalar(data)
+}
+
+func siluScalar(data []float32) {
+	for i := range data {
+		v := data[i]
+		data[i] = v / (1.0 + float32(math.Exp(float64(-v))))
+	}
+}
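
A small sketch of how SwiGLU is typically driven from a gated MLP; the 1x4 tensors and the 0.5 fill value are illustrative only:

package main

import (
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/backend/cpu/ops"
	"makarna/pkg/tensor"
)

func main() {
	gate := cpu.NewTensor(tensor.Shape{1, 4}, []float32{-2, -1, 1, 2})
	up := ops.Full(tensor.Shape{1, 4}, 0.5)
	out := ops.Zeros(tensor.Shape{1, 4})

	// out[i] = (gate[i] * sigmoid(gate[i])) * up[i]; gate keeps its values.
	if err := nn.SwiGLU(gate, up, out); err != nil {
		panic(err)
	}
	fmt.Println(out.DataFloat32())
	fmt.Println(gate.DataFloat32()) // still [-2 -1 1 2]
}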

+ 15 - 0
pkg/backend/cpu/nn/silu_amd64.go

@@ -0,0 +1,15 @@
+//go:build amd64
+
+package nn
+
+const (
+	hasSiLUAVX2   = true
+	hasSiLUAVX512 = true
+)
+
+//go:noescape
+func siluAVX2Asm(x *float32, n int)
+
+//go:noescape
+func siluAVX512Asm(x *float32, n int)
+

+ 133 - 0
pkg/backend/cpu/nn/silu_avx2.s

@@ -0,0 +1,133 @@
+//go:build amd64
+// +build amd64
+
+#include "textflag.h"
+
+// func siluAVX2Asm(x *float32, n int)
+TEXT ·siluAVX2Asm(SB), NOSPLIT, $0-16
+	// Load args
+	MOVQ x+0(FP), DI
+	MOVQ n+8(FP), CX
+
+	CMPQ CX, $0
+	JLE done
+
+	// Broadcast constants
+	VBROADCASTSS ·expHi(SB), Y14
+	VBROADCASTSS ·expLo(SB), Y13
+	VBROADCASTSS ·log2EF(SB), Y12
+	VBROADCASTSS ·halfConst(SB), Y11
+	VBROADCASTSS ·expC1(SB), Y10
+	VBROADCASTSS ·expC2(SB), Y9
+	VBROADCASTSS ·oneConst(SB), Y8
+	VPBROADCASTD ·signMaskConst(SB), Y15
+
+loop:
+	CMPQ CX, $8
+	JL done
+
+	VMOVUPS (DI), Y0         // original x
+	VMOVAPS Y0, Y1           // copy for neg
+	VXORPS Y15, Y1, Y1       // y1 = -x
+
+	VMINPS Y14, Y1, Y1       // clamp hi
+	VMAXPS Y13, Y1, Y1       // clamp lo
+
+	VMULPS Y12, Y1, Y2       // y2 = x * log2e
+	VADDPS Y11, Y2, Y2       // +0.5
+	VROUNDPS $1, Y2, Y2      // floor
+
+	VCVTPS2DQ Y2, Y6         // integer exponent
+	VCVTDQ2PS Y6, Y5         // fx as float
+
+	VMULPS Y10, Y5, Y3       // fx * C1
+	VSUBPS Y3, Y1, Y1
+	VMULPS Y9, Y5, Y3        // fx * C2
+	VSUBPS Y3, Y1, Y1
+
+	VMULPS Y1, Y1, Y3        // z = x*x
+
+	VBROADCASTSS ·polyP0(SB), Y4
+	VMULPS Y1, Y4, Y4
+	VBROADCASTSS ·polyP1(SB), Y5
+	VADDPS Y5, Y4, Y4
+	VMULPS Y1, Y4, Y4
+	VBROADCASTSS ·polyP2(SB), Y5
+	VADDPS Y5, Y4, Y4
+	VMULPS Y1, Y4, Y4
+	VBROADCASTSS ·polyP3(SB), Y5
+	VADDPS Y5, Y4, Y4
+	VMULPS Y1, Y4, Y4
+	VBROADCASTSS ·polyP4(SB), Y5
+	VADDPS Y5, Y4, Y4
+	VMULPS Y1, Y4, Y4
+	VBROADCASTSS ·polyP5(SB), Y5
+	VADDPS Y5, Y4, Y4
+
+	VMULPS Y3, Y4, Y4        // y *= z
+	VADDPS Y1, Y4, Y4        // y += x
+	VADDPS Y8, Y4, Y4        // y += 1
+
+	VPBROADCASTD ·expBiasConst(SB), Y5
+	VPADDD Y5, Y6, Y6
+	VPSLLD $23, Y6, Y6
+	VMULPS Y6, Y4, Y4        // exp(-x)
+
+	VADDPS Y8, Y4, Y3        // denom = 1 + exp(-x)
+	VDIVPS Y3, Y8, Y3        // 1 / denom
+	VMULPS Y0, Y3, Y0        // x * sigmoid(x)
+	VMOVUPS Y0, (DI)
+
+	ADDQ $32, DI
+	SUBQ $8, CX
+	JMP loop
+
+done:
+	RET
+
+// Constants (single-float broadcast)
+DATA ·expHi+0(SB)/4, $0x42b0c0a5
+GLOBL ·expHi(SB), RODATA, $4
+
+DATA ·expLo+0(SB)/4, $0xc2b0c0a5
+GLOBL ·expLo(SB), RODATA, $4
+
+DATA ·log2EF+0(SB)/4, $0x3fb8aa3b
+GLOBL ·log2EF(SB), RODATA, $4
+
+DATA ·halfConst+0(SB)/4, $0x3f000000
+GLOBL ·halfConst(SB), RODATA, $4
+
+DATA ·expC1+0(SB)/4, $0x3f318000
+GLOBL ·expC1(SB), RODATA, $4
+
+DATA ·expC2+0(SB)/4, $0xb95e8083
+GLOBL ·expC2(SB), RODATA, $4
+
+DATA ·polyP0+0(SB)/4, $0x39506967
+GLOBL ·polyP0(SB), RODATA, $4
+
+DATA ·polyP1+0(SB)/4, $0x3ab743ce
+GLOBL ·polyP1(SB), RODATA, $4
+
+DATA ·polyP2+0(SB)/4, $0x3c088908
+GLOBL ·polyP2(SB), RODATA, $4
+
+DATA ·polyP3+0(SB)/4, $0x3d2aa9c1
+GLOBL ·polyP3(SB), RODATA, $4
+
+DATA ·polyP4+0(SB)/4, $0x3e2aaaaa
+GLOBL ·polyP4(SB), RODATA, $4
+
+DATA ·polyP5+0(SB)/4, $0x3f000000
+GLOBL ·polyP5(SB), RODATA, $4
+
+DATA ·oneConst+0(SB)/4, $0x3f800000
+GLOBL ·oneConst(SB), RODATA, $4
+
+DATA ·signMaskConst+0(SB)/4, $0x80000000
+GLOBL ·signMaskConst(SB), RODATA, $4
+
+DATA ·expBiasConst+0(SB)/4, $0x0000007f
+GLOBL ·expBiasConst(SB), RODATA, $4
+

+ 87 - 0
pkg/backend/cpu/nn/silu_avx512.s

@@ -0,0 +1,87 @@
+//go:build amd64
+// +build amd64
+
+#include "textflag.h"
+
+// func siluAVX512Asm(x *float32, n int)
+TEXT ·siluAVX512Asm(SB), NOSPLIT, $0-16
+	// Load args
+	MOVQ x+0(FP), DI
+	MOVQ n+8(FP), CX
+
+	CMPQ CX, $0
+	JLE done
+
+	// Broadcast constants
+	VBROADCASTSS ·expHi(SB), Z14
+	VBROADCASTSS ·expLo(SB), Z13
+	VBROADCASTSS ·log2EF(SB), Z12
+	VBROADCASTSS ·halfConst(SB), Z11
+	VBROADCASTSS ·expC1(SB), Z10
+	VBROADCASTSS ·expC2(SB), Z9
+	VBROADCASTSS ·oneConst(SB), Z8
+	VPBROADCASTD ·signMaskConst(SB), Z15
+
+loop:
+	CMPQ CX, $16
+	JL done
+
+	VMOVUPS (DI), Z0         // original x
+	VMOVAPS Z0, Z1           // copy for neg
+	VXORPS Z15, Z1, Z1       // z1 = -x
+
+	VMINPS Z14, Z1, Z1       // clamp hi
+	VMAXPS Z13, Z1, Z1       // clamp lo
+
+	VMULPS Z12, Z1, Z2       // z2 = x * log2e
+	VADDPS Z11, Z2, Z2       // +0.5
+	VRNDSCALEPS $1, Z2, Z2   // floor
+
+	VCVTPS2DQ Z2, Z6         // integer exponent
+	VCVTDQ2PS Z6, Z5         // fx as float
+
+	VMULPS Z10, Z5, Z3       // fx * C1
+	VSUBPS Z3, Z1, Z1
+	VMULPS Z9, Z5, Z3        // fx * C2
+	VSUBPS Z3, Z1, Z1
+
+	VMULPS Z1, Z1, Z3        // z = x*x
+
+	VBROADCASTSS ·polyP0(SB), Z4
+	VMULPS Z1, Z4, Z4
+	VBROADCASTSS ·polyP1(SB), Z5
+	VADDPS Z5, Z4, Z4
+	VMULPS Z1, Z4, Z4
+	VBROADCASTSS ·polyP2(SB), Z5
+	VADDPS Z5, Z4, Z4
+	VMULPS Z1, Z4, Z4
+	VBROADCASTSS ·polyP3(SB), Z5
+	VADDPS Z5, Z4, Z4
+	VMULPS Z1, Z4, Z4
+	VBROADCASTSS ·polyP4(SB), Z5
+	VADDPS Z5, Z4, Z4
+	VMULPS Z1, Z4, Z4
+	VBROADCASTSS ·polyP5(SB), Z5
+	VADDPS Z5, Z4, Z4
+
+	VMULPS Z3, Z4, Z4        // y *= z
+	VADDPS Z1, Z4, Z4        // y += x
+	VADDPS Z8, Z4, Z4        // y += 1
+
+	VPBROADCASTD ·expBiasConst(SB), Z5
+	VPADDD Z5, Z6, Z6
+	VPSLLD $23, Z6, Z6
+	VMULPS Z6, Z4, Z4        // exp(-x)
+
+	VADDPS Z8, Z4, Z3        // denom = 1 + exp(-x)
+	VDIVPS Z3, Z8, Z3        // 1 / denom
+	VMULPS Z0, Z3, Z0        // x * sigmoid(x)
+	VMOVUPS Z0, (DI)
+
+	ADDQ $64, DI
+	SUBQ $16, CX
+	JMP loop
+
+done:
+	RET
+

+ 12 - 0
pkg/backend/cpu/nn/silu_noasm.go

@@ -0,0 +1,12 @@
+//go:build !amd64
+
+package nn
+
+const (
+	hasSiLUAVX2   = false
+	hasSiLUAVX512 = false
+)
+
+func siluAVX2Asm(_ *float32, _ int)   {}
+func siluAVX512Asm(_ *float32, _ int) {}
+

+ 72 - 0
pkg/backend/cpu/nn/softmax.go

@@ -0,0 +1,72 @@
+package nn
+
+import (
+	"errors"
+	"math"
+
+	"makarna/pkg/backend/cpu"
+)
+
+// Softmax applies softmax normalization in-place.
+// Dispatch order: AVX-512 -> AVX2 -> scalar fallback.
+func Softmax(x *cpu.Tensor) error {
+	data := x.DataFloat32()
+	if len(data) == 0 {
+		return nil
+	}
+
+	softmaxInplace(data)
+	var sum float32
+	for _, v := range data {
+		if v < 0 || math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
+			return errors.New("softmax produced invalid value")
+		}
+		sum += v
+	}
+	if sum <= 0 || math.IsNaN(float64(sum)) || math.IsInf(float64(sum), 0) {
+		return errors.New("softmax produced invalid sum")
+	}
+	d := float32(math.Abs(float64(sum - 1)))
+	if d > 1e-3 {
+		inv := 1 / sum
+		for i := range data {
+			data[i] *= inv
+		}
+	}
+	return nil
+}
+
+func softmaxInplace(data []float32) {
+	if len(data) == 0 {
+		return
+	}
+	switch {
+	case hasSoftmaxAVX512 && cpu.SupportsAVX512() && len(data) >= 16:
+		softmaxAVX512(data)
+	case hasSoftmaxAVX2 && cpu.SupportsAVX2() && len(data) >= 8:
+		softmaxAVX2(data)
+	default:
+		softmaxScalar(data)
+	}
+}
+
+func softmaxScalar(data []float32) {
+	maxVal := float32(-math.MaxFloat32)
+	for _, v := range data {
+		if v > maxVal {
+			maxVal = v
+		}
+	}
+
+	var sum float32
+	for i, v := range data {
+		ev := float32(math.Exp(float64(v - maxVal)))
+		data[i] = ev
+		sum += ev
+	}
+
+	inv := 1.0 / sum
+	for i := range data {
+		data[i] *= inv
+	}
+}
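
Both softmaxScalar and the SIMD paths subtract the running maximum before exponentiating, which keeps exp() from overflowing on large logits; the wrapper then renormalizes only if the computed sum drifts. A tiny illustration (the logit values are arbitrary):

package main

import (
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/nn"
	"makarna/pkg/tensor"
)

func main() {
	// Naively exponentiating 1000+ would overflow float32; max-subtraction
	// turns these into exp(-3)..exp(0), so the result matches softmax(0,1,2,3).
	logits := cpu.NewTensor(tensor.Shape{4}, []float32{1000, 1001, 1002, 1003})
	if err := nn.Softmax(logits); err != nil {
		panic(err)
	}
	fmt.Println(logits.DataFloat32())
}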

+ 94 - 0
pkg/backend/cpu/nn/softmax_amd64.go

@@ -0,0 +1,94 @@
+//go:build amd64
+
+package nn
+
+import "math"
+
+const (
+	hasSoftmaxAVX2   = true
+	hasSoftmaxAVX512 = true
+)
+
+//go:noescape
+func softmaxMaxAVX2Asm(x *float32, n int) float32
+
+//go:noescape
+func softmaxExpSumAVX2Asm(x *float32, n int, max float32) float32
+
+//go:noescape
+func softmaxScaleAVX2Asm(x *float32, n int, inv float32)
+
+//go:noescape
+func softmaxMaxAVX512Asm(x *float32, n int) float32
+
+//go:noescape
+func softmaxExpSumAVX512Asm(x *float32, n int, max float32) float32
+
+//go:noescape
+func softmaxScaleAVX512Asm(x *float32, n int, inv float32)
+
+func softmaxAVX2(data []float32) {
+	n := len(data)
+	main := n &^ 7 // process multiples of 8
+
+	maxVal := float32(-math.MaxFloat32)
+	if main > 0 {
+		maxVal = softmaxMaxAVX2Asm(&data[0], main)
+	}
+	for i := main; i < n; i++ {
+		if data[i] > maxVal {
+			maxVal = data[i]
+		}
+	}
+
+	var sum float32
+	if main > 0 {
+		sum += softmaxExpSumAVX2Asm(&data[0], main, maxVal)
+	}
+	for i := main; i < n; i++ {
+		ev := float32(math.Exp(float64(data[i] - maxVal)))
+		data[i] = ev
+		sum += ev
+	}
+
+	inv := 1.0 / sum
+	if main > 0 {
+		softmaxScaleAVX2Asm(&data[0], main, inv)
+	}
+	for i := main; i < n; i++ {
+		data[i] *= inv
+	}
+}
+
+func softmaxAVX512(data []float32) {
+	n := len(data)
+	main := n &^ 15 // process multiples of 16
+
+	maxVal := float32(-math.MaxFloat32)
+	if main > 0 {
+		maxVal = softmaxMaxAVX512Asm(&data[0], main)
+	}
+	for i := main; i < n; i++ {
+		if data[i] > maxVal {
+			maxVal = data[i]
+		}
+	}
+
+	var sum float32
+	if main > 0 {
+		sum += softmaxExpSumAVX512Asm(&data[0], main, maxVal)
+	}
+	for i := main; i < n; i++ {
+		ev := float32(math.Exp(float64(data[i] - maxVal)))
+		data[i] = ev
+		sum += ev
+	}
+
+	inv := 1.0 / sum
+	if main > 0 {
+		softmaxScaleAVX512Asm(&data[0], main, inv)
+	}
+	for i := main; i < n; i++ {
+		data[i] *= inv
+	}
+}

+ 160 - 0
pkg/backend/cpu/nn/softmax_avx2.s

@@ -0,0 +1,160 @@
+//go:build amd64
+// +build amd64
+
+#include "textflag.h"
+
+// func softmaxMaxAVX2Asm(x *float32, n int) float32
+TEXT ·softmaxMaxAVX2Asm(SB), NOSPLIT, $0-24
+	MOVQ x+0(FP), DI
+	MOVQ n+8(FP), CX
+
+	VBROADCASTSS ·negInfConst(SB), Y0
+
+loop_max:
+	CMPQ CX, $8
+	JL max_reduce
+	VMOVUPS (DI), Y1
+	VMAXPS Y1, Y0, Y0
+	ADDQ $32, DI
+	SUBQ $8, CX
+	JMP loop_max
+
+max_reduce:
+	VEXTRACTF128 $0, Y0, X0
+	VEXTRACTF128 $1, Y0, X1
+	VMAXPS X1, X0, X0
+	VPERMILPS $0x4E, X0, X1
+	VMAXPS X1, X0, X0
+	VPERMILPS $0xB1, X0, X1
+	VMAXPS X1, X0, X0
+
+	// tail
+	TESTQ CX, CX
+	JE max_done
+max_tail:
+	VMOVSS (DI), X1
+	VMAXSS X1, X0, X0
+	ADDQ $4, DI
+	DECQ CX
+	JNZ max_tail
+
+max_done:
+	MOVSS X0, ret+16(FP)
+	RET
+
+// func softmaxExpSumAVX2Asm(x *float32, n int, max float32) float32
+TEXT ·softmaxExpSumAVX2Asm(SB), NOSPLIT, $0-28
+	MOVQ x+0(FP), DI
+	MOVQ n+8(FP), CX
+	MOVSS max+16(FP), X5
+
+	VBROADCASTSS X5, Y5          // max
+	VBROADCASTSS ·expHi(SB), Y14
+	VBROADCASTSS ·expLo(SB), Y13
+	VBROADCASTSS ·log2EF(SB), Y12
+	VBROADCASTSS ·halfConst(SB), Y11
+	VBROADCASTSS ·expC1(SB), Y10
+	VBROADCASTSS ·expC2(SB), Y9
+	VBROADCASTSS ·oneConst(SB), Y8
+	VPBROADCASTD ·expBiasConst(SB), Y15
+
+	VXORPS Y7, Y7, Y7            // sum accumulator
+
+loop_exp:
+	CMPQ CX, $8
+	JL exp_reduce
+
+	VMOVUPS (DI), Y1
+	VSUBPS Y5, Y1, Y1            // x - max
+	VMINPS Y14, Y1, Y1
+	VMAXPS Y13, Y1, Y1
+
+	VMULPS Y12, Y1, Y2
+	VADDPS Y11, Y2, Y2
+	VROUNDPS $1, Y2, Y2
+
+	VCVTPS2DQ Y2, Y6
+	VCVTDQ2PS Y6, Y3
+
+	VMULPS Y10, Y3, Y4
+	VSUBPS Y4, Y1, Y1
+	VMULPS Y9, Y3, Y4
+	VSUBPS Y4, Y1, Y1
+
+	VMULPS Y1, Y1, Y4            // z = x*x
+
+	VBROADCASTSS ·polyP0(SB), Y0
+	VMULPS Y1, Y0, Y0
+	VBROADCASTSS ·polyP1(SB), Y3
+	VADDPS Y3, Y0, Y0
+	VMULPS Y1, Y0, Y0
+	VBROADCASTSS ·polyP2(SB), Y3
+	VADDPS Y3, Y0, Y0
+	VMULPS Y1, Y0, Y0
+	VBROADCASTSS ·polyP3(SB), Y3
+	VADDPS Y3, Y0, Y0
+	VMULPS Y1, Y0, Y0
+	VBROADCASTSS ·polyP4(SB), Y3
+	VADDPS Y3, Y0, Y0
+	VMULPS Y1, Y0, Y0
+	VBROADCASTSS ·polyP5(SB), Y3
+	VADDPS Y3, Y0, Y0
+
+	VMULPS Y4, Y0, Y0            // y *= z
+	VADDPS Y1, Y0, Y0            // y += x
+	VADDPS Y8, Y0, Y0            // y += 1
+
+	VPADDD Y15, Y6, Y6
+	VPSLLD $23, Y6, Y6
+	VMULPS Y6, Y0, Y0            // exp(x - max)
+
+	VADDPS Y0, Y7, Y7
+	VMOVUPS Y0, (DI)
+
+	ADDQ $32, DI
+	SUBQ $8, CX
+	JMP loop_exp
+
+exp_reduce:
+	VEXTRACTF128 $0, Y7, X0
+	VEXTRACTF128 $1, Y7, X1
+	VADDPS X1, X0, X0
+	VHADDPS X0, X0, X0
+	VHADDPS X0, X0, X0
+
+	MOVSS X0, ret+24(FP)
+	RET
+
+// func softmaxScaleAVX2Asm(x *float32, n int, inv float32)
+TEXT ·softmaxScaleAVX2Asm(SB), NOSPLIT, $0-24
+	MOVQ x+0(FP), DI
+	MOVQ n+8(FP), CX
+	MOVSS inv+16(FP), X1
+	VBROADCASTSS X1, Y1
+
+loop_scale:
+	CMPQ CX, $8
+	JL scale_tail
+	VMULPS (DI), Y1, Y0
+	VMOVUPS Y0, (DI)
+	ADDQ $32, DI
+	SUBQ $8, CX
+	JMP loop_scale
+
+scale_tail:
+	TESTQ CX, CX
+	JE scale_done
+scale_tail_loop:
+	MOVSS (DI), X0
+	VMULSS X1, X0, X0
+	MOVSS X0, (DI)
+	ADDQ $4, DI
+	DECQ CX
+	JNZ scale_tail_loop
+
+scale_done:
+	RET
+
+// Additional constants
+DATA ·negInfConst+0(SB)/4, $0xff800000
+GLOBL ·negInfConst(SB), RODATA, $4

+ 158 - 0
pkg/backend/cpu/nn/softmax_avx512.s

@@ -0,0 +1,158 @@
+//go:build amd64
+// +build amd64
+
+#include "textflag.h"
+
+// func softmaxMaxAVX512Asm(x *float32, n int) float32
+TEXT ·softmaxMaxAVX512Asm(SB), NOSPLIT, $0-24
+	MOVQ x+0(FP), DI
+	MOVQ n+8(FP), CX
+
+	VBROADCASTSS ·negInfConst(SB), Z0
+
+loop_max:
+	CMPQ CX, $16
+	JL max_reduce
+	VMOVUPS (DI), Z1
+	VMAXPS Z1, Z0, Z0
+	ADDQ $64, DI
+	SUBQ $16, CX
+	JMP loop_max
+
+max_reduce:
+	VEXTRACTF32X8 $1, Z0, Y1
+	VMAXPS Y1, Y0, Y0
+	VEXTRACTF128 $1, Y0, X1
+	VMAXPS X1, X0, X0
+	VPERMILPS $0x4E, X0, X1
+	VMAXPS X1, X0, X0
+	VPERMILPS $0xB1, X0, X1
+	VMAXPS X1, X0, X0
+
+	// tail
+	TESTQ CX, CX
+	JE max_done
+max_tail:
+	VMOVSS (DI), X1
+	VMAXSS X1, X0, X0
+	ADDQ $4, DI
+	DECQ CX
+	JNZ max_tail
+
+max_done:
+	MOVSS X0, ret+16(FP)
+	RET
+
+// func softmaxExpSumAVX512Asm(x *float32, n int, max float32) float32
+TEXT ·softmaxExpSumAVX512Asm(SB), NOSPLIT, $0-28
+	MOVQ x+0(FP), DI
+	MOVQ n+8(FP), CX
+	MOVSS max+16(FP), X5
+
+	VBROADCASTSS X5, Z5          // max
+	VBROADCASTSS ·expHi(SB), Z14
+	VBROADCASTSS ·expLo(SB), Z13
+	VBROADCASTSS ·log2EF(SB), Z12
+	VBROADCASTSS ·halfConst(SB), Z11
+	VBROADCASTSS ·expC1(SB), Z10
+	VBROADCASTSS ·expC2(SB), Z9
+	VBROADCASTSS ·oneConst(SB), Z8
+	VPBROADCASTD ·expBiasConst(SB), Z15
+
+	VXORPS Z7, Z7, Z7            // sum accumulator
+
+loop_exp:
+	CMPQ CX, $16
+	JL exp_reduce
+
+	VMOVUPS (DI), Z1
+	VSUBPS Z5, Z1, Z1            // x - max
+	VMINPS Z14, Z1, Z1
+	VMAXPS Z13, Z1, Z1
+
+	VMULPS Z12, Z1, Z2
+	VADDPS Z11, Z2, Z2
+	VRNDSCALEPS $1, Z2, Z2
+
+	VCVTPS2DQ Z2, Z6
+	VCVTDQ2PS Z6, Z3
+
+	VMULPS Z10, Z3, Z4
+	VSUBPS Z4, Z1, Z1
+	VMULPS Z9, Z3, Z4
+	VSUBPS Z4, Z1, Z1
+
+	VMULPS Z1, Z1, Z4            // z = x*x
+
+	VBROADCASTSS ·polyP0(SB), Z0
+	VMULPS Z1, Z0, Z0
+	VBROADCASTSS ·polyP1(SB), Z3
+	VADDPS Z3, Z0, Z0
+	VMULPS Z1, Z0, Z0
+	VBROADCASTSS ·polyP2(SB), Z3
+	VADDPS Z3, Z0, Z0
+	VMULPS Z1, Z0, Z0
+	VBROADCASTSS ·polyP3(SB), Z3
+	VADDPS Z3, Z0, Z0
+	VMULPS Z1, Z0, Z0
+	VBROADCASTSS ·polyP4(SB), Z3
+	VADDPS Z3, Z0, Z0
+	VMULPS Z1, Z0, Z0
+	VBROADCASTSS ·polyP5(SB), Z3
+	VADDPS Z3, Z0, Z0
+
+	VMULPS Z4, Z0, Z0            // y *= z
+	VADDPS Z1, Z0, Z0            // y += x
+	VADDPS Z8, Z0, Z0            // y += 1
+
+	VPADDD Z15, Z6, Z6
+	VPSLLD $23, Z6, Z6
+	VMULPS Z6, Z0, Z0            // exp(x - max)
+
+	VADDPS Z0, Z7, Z7
+	VMOVUPS Z0, (DI)
+
+	ADDQ $64, DI
+	SUBQ $16, CX
+	JMP loop_exp
+
+exp_reduce:
+	VEXTRACTF32X8 $1, Z7, Y1
+	VADDPS Y1, Y0, Y0
+	VEXTRACTF128 $1, Y0, X1
+	VADDPS X1, X0, X0
+	VHADDPS X0, X0, X0
+	VHADDPS X0, X0, X0
+
+	MOVSS X0, ret+24(FP)
+	RET
+
+// func softmaxScaleAVX512Asm(x *float32, n int, inv float32)
+TEXT ·softmaxScaleAVX512Asm(SB), NOSPLIT, $0-24
+	MOVQ x+0(FP), DI
+	MOVQ n+8(FP), CX
+	MOVSS inv+16(FP), X1
+	VBROADCASTSS X1, Z1
+
+loop_scale:
+	CMPQ CX, $16
+	JL scale_tail
+	VMULPS (DI), Z1, Z0
+	VMOVUPS Z0, (DI)
+	ADDQ $64, DI
+	SUBQ $16, CX
+	JMP loop_scale
+
+scale_tail:
+	TESTQ CX, CX
+	JE scale_done
+scale_tail_loop:
+	MOVSS (DI), X0
+	VMULSS X1, X0, X0
+	MOVSS X0, (DI)
+	ADDQ $4, DI
+	DECQ CX
+	JNZ scale_tail_loop
+
+scale_done:
+	RET

+ 11 - 0
pkg/backend/cpu/nn/softmax_noasm.go

@@ -0,0 +1,11 @@
+//go:build !amd64
+
+package nn
+
+const (
+	hasSoftmaxAVX2   = false
+	hasSoftmaxAVX512 = false
+)
+
+func softmaxAVX2(data []float32)   { softmaxScalar(data) }
+func softmaxAVX512(data []float32) { softmaxScalar(data) }

+ 21 - 0
pkg/backend/cpu/nn/tensor_utils.go

@@ -0,0 +1,21 @@
+package nn
+
+import (
+	"fmt"
+
+	"makarna/pkg/backend/cpu"
+)
+
+// FlattenVector returns the first n float32 values of t, or an error if t is
+// nil or holds fewer than n elements.
+func FlattenVector(t *cpu.Tensor, n int, name string) ([]float32, error) {
+	if t == nil {
+		return nil, fmt.Errorf("missing %s", name)
+	}
+	data := t.DataFloat32()
+	if len(data) == n {
+		return data, nil
+	}
+	if len(data) >= n {
+		return data[:n], nil
+	}
+	return nil, fmt.Errorf("%s has unexpected size %d (need at least %d)", name, len(data), n)
+}

+ 39 - 0
pkg/backend/cpu/nn/transformer.go

@@ -0,0 +1,39 @@
+package nn
+
+import (
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/backend/cpu/ops"
+)
+
+// TransformerBlock applies a standard pre-norm transformer block:
+// 1. RMSNorm → Attention → Add residual
+// 2. RMSNorm → MLP → Add residual
+//
+// attnFn: function that computes attention given normalized input
+// mlpFn: function that computes MLP given normalized input
+func TransformerBlock(
+	hidden *cpu.Tensor,
+	attnNorm, mlpNorm *cpu.Tensor,
+	eps float32,
+	attnFn func(*cpu.Tensor) *cpu.Tensor,
+	mlpFn func(*cpu.Tensor) *cpu.Tensor,
+) *cpu.Tensor {
+	// Save residual
+	residual := ops.Zeros(hidden.Shape())
+	ops.Copy(residual, hidden)
+
+	// Attention sub-block
+	RMSNorm(hidden, attnNorm, eps)
+	attnOut := attnFn(hidden)
+	ops.Add(residual, attnOut)
+	ops.Copy(hidden, residual)
+
+	// MLP sub-block
+	ops.Copy(residual, hidden)
+	RMSNorm(hidden, mlpNorm, eps)
+	mlpOut := mlpFn(hidden)
+	ops.Add(residual, mlpOut)
+	ops.Copy(hidden, residual)
+
+	return hidden
+}
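
A compile-only sketch of how the two callbacks are meant to be wired; layerWeights and runLayer are hypothetical names invented for this example, only TransformerBlock, SelfAttention, and SelfAttentionConfig come from the commit:

package model

import (
	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/nn"
)

// layerWeights is a hypothetical container for one decoder layer's tensors.
type layerWeights struct {
	attnNorm, mlpNorm *cpu.Tensor
	wq, wk, wv, wo    *cpu.Tensor
	qNorm, kNorm      *cpu.Tensor
}

// runLayer shows the intended callback shape: each closure receives the
// already-normalized hidden state and returns its sub-block output.
func runLayer(hidden *cpu.Tensor, w layerWeights, positions []int, cfg nn.SelfAttentionConfig) *cpu.Tensor {
	return nn.TransformerBlock(hidden, w.attnNorm, w.mlpNorm, cfg.RMSNormEps,
		func(x *cpu.Tensor) *cpu.Tensor {
			return nn.SelfAttention(x, w.wq, w.wk, w.wv, w.wo, w.qNorm, w.kNorm, positions, cfg, nil, 0)
		},
		func(x *cpu.Tensor) *cpu.Tensor {
			return x // placeholder for the gated MLP (gate/up projections + SwiGLU + down)
		},
	)
}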

+ 17 - 0
pkg/backend/cpu/ops/add.go

@@ -0,0 +1,17 @@
+// Package ops provides basic tensor operations for CPU backend
+package ops
+
+import (
+	"makarna/pkg/backend/cpu"
+)
+
+// Add performs element-wise addition: dst += src
+func Add(dst, src *cpu.Tensor) error {
+	dstData := dst.DataFloat32()
+	srcData := src.DataFloat32()
+	
+	for i := range dstData {
+		dstData[i] += srcData[i]
+	}
+	return nil
+}

+ 11 - 0
pkg/backend/cpu/ops/copy.go

@@ -0,0 +1,11 @@
+package ops
+
+import "makarna/pkg/backend/cpu"
+
+// Copy copies data from src to dst
+func Copy(dst, src *cpu.Tensor) error {
+	dstData := dst.DataFloat32()
+	srcData := src.DataFloat32()
+	copy(dstData, srcData)
+	return nil
+}

+ 42 - 0
pkg/backend/cpu/ops/factory.go

@@ -0,0 +1,42 @@
+package ops
+
+import (
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/tensor"
+)
+
+// Zeros creates a new zero-filled tensor
+func Zeros(shape tensor.Shape) *cpu.Tensor {
+	return cpu.NewTensor(shape, nil) // Go zeros by default
+}
+
+// Ones creates a new tensor filled with 1.0
+func Ones(shape tensor.Shape) *cpu.Tensor {
+	data := make([]float32, shape.NumElements())
+	for i := range data {
+		data[i] = 1.0
+	}
+	return cpu.NewTensor(shape, data)
+}
+
+// Full creates a new tensor filled with a specific value
+func Full(shape tensor.Shape, value float32) *cpu.Tensor {
+	data := make([]float32, shape.NumElements())
+	for i := range data {
+		data[i] = value
+	}
+	return cpu.NewTensor(shape, data)
+}
+
+// Arange creates a 1D tensor with values start, start+step, start+2*step, ...,
+// stopping before end (half-open interval).
+func Arange(start, end, step float32) *cpu.Tensor {
+	n := int((end - start) / step)
+	if n <= 0 {
+		return cpu.NewTensor(tensor.Shape{0}, nil)
+	}
+	data := make([]float32, n)
+	for i := 0; i < n; i++ {
+		data[i] = start + float32(i)*step
+	}
+	return cpu.NewTensor(tensor.Shape{n}, data)
+}
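
One detail worth keeping in mind: Arange is half-open, and the element count comes from truncating (end-start)/step, e.g.:

package main

import (
	"fmt"

	"makarna/pkg/backend/cpu/ops"
)

func main() {
	t := ops.Arange(0, 1, 0.25)
	fmt.Println(t.Shape(), t.DataFloat32()) // [4] [0 0.25 0.5 0.75] -- end itself is excluded
}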

+ 14 - 0
pkg/backend/cpu/ops/mul.go

@@ -0,0 +1,14 @@
+package ops
+
+import "makarna/pkg/backend/cpu"
+
+// Mul performs element-wise multiplication: dst *= src
+func Mul(dst, src *cpu.Tensor) error {
+	dstData := dst.DataFloat32()
+	srcData := src.DataFloat32()
+	
+	for i := range dstData {
+		dstData[i] *= srcData[i]
+	}
+	return nil
+}

+ 58 - 0
pkg/backend/cpu/ops/permute.go

@@ -0,0 +1,58 @@
+package ops
+
+import (
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/tensor"
+)
+
+// Permute reorders tensor dimensions
+// Example: Permute(t, 1, 0, 2) swaps first two dims
+// Returns a new tensor (not a view - copies data for simplicity)
+func Permute(t *cpu.Tensor, order ...int) *cpu.Tensor {
+	oldShape := t.Shape()
+	oldData := t.DataFloat32()
+	
+	// Compute new shape
+	newShape := make(tensor.Shape, len(order))
+	for i, o := range order {
+		newShape[i] = oldShape[o]
+	}
+	
+	// Compute old strides
+	oldStrides := make([]int, len(oldShape))
+	oldStrides[len(oldShape)-1] = 1
+	for i := len(oldShape) - 2; i >= 0; i-- {
+		oldStrides[i] = oldStrides[i+1] * oldShape[i+1]
+	}
+	
+	// Compute new strides  
+	newStrides := make([]int, len(newShape))
+	newStrides[len(newShape)-1] = 1
+	for i := len(newShape) - 2; i >= 0; i-- {
+		newStrides[i] = newStrides[i+1] * newShape[i+1]
+	}
+	
+	newData := make([]float32, newShape.NumElements())
+	
+	// Iterate over all new indices and map to old
+	indices := make([]int, len(newShape))
+	for i := 0; i < len(newData); i++ {
+		// Compute old flat index
+		oldFlatIdx := 0
+		for d := 0; d < len(order); d++ {
+			oldFlatIdx += indices[d] * oldStrides[order[d]]
+		}
+		newData[i] = oldData[oldFlatIdx]
+		
+		// Increment indices
+		for d := len(indices) - 1; d >= 0; d-- {
+			indices[d]++
+			if indices[d] < newShape[d] {
+				break
+			}
+			indices[d] = 0
+		}
+	}
+	
+	return cpu.NewTensor(newShape, newData)
+}
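
Since Permute copies into a fresh contiguous buffer instead of adjusting strides, a 2-D transpose looks like this (the 2x3 matrix is illustrative):

package main

import (
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/ops"
	"makarna/pkg/tensor"
)

func main() {
	// 2x3 row-major matrix.
	m := cpu.NewTensor(tensor.Shape{2, 3}, []float32{
		1, 2, 3,
		4, 5, 6,
	})
	mt := ops.Permute(m, 1, 0) // transpose
	fmt.Println(mt.Shape(), mt.DataFloat32()) // [3 2] [1 4 2 5 3 6]
}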

+ 95 - 0
pkg/backend/cpu/ops/repeat.go

@@ -0,0 +1,95 @@
+package ops
+
+import (
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/tensor"
+)
+
+// Repeat repeats tensor n times along dimension dim
+func Repeat(t *cpu.Tensor, dim, n int) *cpu.Tensor {
+	shape := t.Shape()
+	data := t.DataFloat32()
+	
+	newShape := make(tensor.Shape, len(shape))
+	copy(newShape, shape)
+	newShape[dim] *= n
+	
+	// Simple case: repeat along first dimension
+	if dim == 0 {
+		newData := make([]float32, 0, newShape.NumElements())
+		for i := 0; i < n; i++ {
+			newData = append(newData, data...)
+		}
+		return cpu.NewTensor(newShape, newData)
+	}
+	
+	// Simple case: repeat along last dimension
+	if dim == len(shape)-1 {
+		newData := make([]float32, newShape.NumElements())
+		rowSize := shape[dim]
+		newRowSize := newShape[dim]
+		numRows := len(data) / rowSize
+		
+		for row := 0; row < numRows; row++ {
+			srcStart := row * rowSize
+			dstStart := row * newRowSize
+			for rep := 0; rep < n; rep++ {
+				copy(newData[dstStart+rep*rowSize:], data[srcStart:srcStart+rowSize])
+			}
+		}
+		return cpu.NewTensor(newShape, newData)
+	}
+	
+	// General case: repeat along middle dimension
+	// Calculate outer (before dim), inner (after dim) sizes
+	outerSize := 1
+	for i := 0; i < dim; i++ {
+		outerSize *= shape[i]
+	}
+	innerSize := 1
+	for i := dim + 1; i < len(shape); i++ {
+		innerSize *= shape[i]
+	}
+	dimSize := shape[dim]
+	sliceSize := dimSize * innerSize
+	
+	newData := make([]float32, newShape.NumElements())
+	dstIdx := 0
+	
+	for outer := 0; outer < outerSize; outer++ {
+		srcStart := outer * sliceSize
+		srcSlice := data[srcStart : srcStart+sliceSize]
+		for rep := 0; rep < n; rep++ {
+			copy(newData[dstIdx:], srcSlice)
+			dstIdx += sliceSize
+		}
+	}
+	
+	return cpu.NewTensor(newShape, newData)
+}
+
+// RepeatInterleave repeats each slice along dim n times back-to-back
+// (1,1,2,2,... rather than Repeat's 1,2,1,2,...).
+// Only the 2D, dim=0 case is currently implemented; other layouts return nil.
+func RepeatInterleave(t *cpu.Tensor, dim, n int) *cpu.Tensor {
+	shape := t.Shape()
+	data := t.DataFloat32()
+	
+	newShape := make(tensor.Shape, len(shape))
+	copy(newShape, shape)
+	newShape[dim] *= n
+	
+	// Handle 2D case with dim=0
+	if len(shape) == 2 && dim == 0 {
+		newData := make([]float32, newShape.NumElements())
+		rowSize := shape[1]
+		for row := 0; row < shape[0]; row++ {
+			srcRow := data[row*rowSize : (row+1)*rowSize]
+			for rep := 0; rep < n; rep++ {
+				dstStart := (row*n + rep) * rowSize
+				copy(newData[dstStart:], srcRow)
+			}
+		}
+		return cpu.NewTensor(newShape, newData)
+	}
+	
+	return nil
+}
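
The difference between the two helpers is easiest to see on a small example (when expanding KV heads for grouped-query attention, the interleaved order is usually the one wanted); the 2x2 tensor below is illustrative:

package main

import (
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/ops"
	"makarna/pkg/tensor"
)

func main() {
	kv := cpu.NewTensor(tensor.Shape{2, 2}, []float32{
		1, 1,
		2, 2,
	})

	tiled := ops.Repeat(kv, 0, 2)                 // row order: 1,2,1,2
	interleaved := ops.RepeatInterleave(kv, 0, 2) // row order: 1,1,2,2

	fmt.Println(tiled.DataFloat32())       // [1 1 2 2 1 1 2 2]
	fmt.Println(interleaved.DataFloat32()) // [1 1 1 1 2 2 2 2]
}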

+ 45 - 0
pkg/backend/cpu/ops/reshape.go

@@ -0,0 +1,45 @@
+package ops
+
+import (
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/tensor"
+)
+
+// Reshape returns a new tensor with different shape (same data)
+func Reshape(t *cpu.Tensor, shape tensor.Shape) *cpu.Tensor {
+	return cpu.NewTensor(shape, t.DataFloat32())
+}
+
+// View is an alias for Reshape
+func View(t *cpu.Tensor, shape tensor.Shape) *cpu.Tensor {
+	return Reshape(t, shape)
+}
+
+// Squeeze removes dimensions of size 1
+func Squeeze(t *cpu.Tensor) *cpu.Tensor {
+	oldShape := t.Shape()
+	newShape := make(tensor.Shape, 0)
+	for _, s := range oldShape {
+		if s != 1 {
+			newShape = append(newShape, s)
+		}
+	}
+	if len(newShape) == 0 {
+		newShape = tensor.Shape{1} // Scalar
+	}
+	return cpu.NewTensor(newShape, t.DataFloat32())
+}
+
+// Unsqueeze adds a dimension of size 1 at the specified position
+func Unsqueeze(t *cpu.Tensor, dim int) *cpu.Tensor {
+	oldShape := t.Shape()
+	newShape := make(tensor.Shape, len(oldShape)+1)
+	for i := 0; i < dim; i++ {
+		newShape[i] = oldShape[i]
+	}
+	newShape[dim] = 1
+	for i := dim; i < len(oldShape); i++ {
+		newShape[i+1] = oldShape[i]
+	}
+	return cpu.NewTensor(newShape, t.DataFloat32())
+}
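
All four helpers above hand the same backing float32 slice to the new tensor, so the result is a view rather than a copy:

package main

import (
	"fmt"

	"makarna/pkg/backend/cpu/ops"
	"makarna/pkg/tensor"
)

func main() {
	x := ops.Zeros(tensor.Shape{2, 3})
	v := ops.Reshape(x, tensor.Shape{3, 2}) // same data, different shape

	v.DataFloat32()[0] = 42
	fmt.Println(x.DataFloat32()[0]) // 42 -- writes through the view are visible in x
}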

+ 49 - 0
pkg/backend/cpu/ops/scalar.go

@@ -0,0 +1,49 @@
+package ops
+
+import "makarna/pkg/backend/cpu"
+
+// Scale multiplies all elements by a scalar (in-place)
+func Scale(t *cpu.Tensor, s float32) {
+	data := t.DataFloat32()
+	for i := range data {
+		data[i] *= s
+	}
+}
+
+// AddScalar adds a scalar to all elements (in-place)
+func AddScalar(t *cpu.Tensor, s float32) {
+	data := t.DataFloat32()
+	for i := range data {
+		data[i] += s
+	}
+}
+
+// Clamp clamps all values to [min, max] (in-place)
+func Clamp(t *cpu.Tensor, min, max float32) {
+	data := t.DataFloat32()
+	for i := range data {
+		if data[i] < min {
+			data[i] = min
+		} else if data[i] > max {
+			data[i] = max
+		}
+	}
+}
+
+// Neg negates all elements (in-place)
+func Neg(t *cpu.Tensor) {
+	data := t.DataFloat32()
+	for i := range data {
+		data[i] = -data[i]
+	}
+}
+
+// Abs takes absolute value of all elements (in-place)
+func Abs(t *cpu.Tensor) {
+	data := t.DataFloat32()
+	for i := range data {
+		if data[i] < 0 {
+			data[i] = -data[i]
+		}
+	}
+}

+ 121 - 0
pkg/backend/cpu/ops/slice.go

@@ -0,0 +1,121 @@
+package ops
+
+import (
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/tensor"
+)
+
+// Slice extracts a portion of tensor along specified dimension
+// dim: dimension to slice
+// start, end: range [start, end)
+func Slice(t *cpu.Tensor, dim, start, end int) *cpu.Tensor {
+	shape := t.Shape()
+	data := t.DataFloat32()
+	
+	newShape := make(tensor.Shape, len(shape))
+	copy(newShape, shape)
+	newShape[dim] = end - start
+	
+	// Calculate strides
+	strides := make([]int, len(shape))
+	strides[len(shape)-1] = 1
+	for i := len(shape) - 2; i >= 0; i-- {
+		strides[i] = strides[i+1] * shape[i+1]
+	}
+	
+	newData := make([]float32, newShape.NumElements())
+	
+	// Iterate and copy
+	newIdx := 0
+	indices := make([]int, len(shape))
+	for i := 0; i < len(data); i++ {
+		// Check if this index is within slice range for target dim
+		if indices[dim] >= start && indices[dim] < end {
+			newData[newIdx] = data[i]
+			newIdx++
+		}
+		
+		// Increment indices
+		for d := len(indices) - 1; d >= 0; d-- {
+			indices[d]++
+			if indices[d] < shape[d] {
+				break
+			}
+			indices[d] = 0
+		}
+	}
+	
+	return cpu.NewTensor(newShape, newData)
+}
+
+// Concat concatenates tensors along specified dimension
+func Concat(tensors []*cpu.Tensor, dim int) *cpu.Tensor {
+	if len(tensors) == 0 {
+		return nil
+	}
+	if len(tensors) == 1 {
+		return tensors[0]
+	}
+	
+	// Calculate new shape
+	refShape := tensors[0].Shape()
+	newShape := make(tensor.Shape, len(refShape))
+	copy(newShape, refShape)
+	
+	totalDim := 0
+	for _, t := range tensors {
+		totalDim += t.Shape()[dim]
+	}
+	newShape[dim] = totalDim
+	
+	// Simple case: concat along last dimension
+	if dim == len(refShape)-1 {
+		newData := make([]float32, 0, newShape.NumElements())
+		// For each row, append all tensors' data
+		numRows := refShape.NumElements() / refShape[dim]
+		for row := 0; row < numRows; row++ {
+			for _, t := range tensors {
+				tData := t.DataFloat32()
+				rowSize := t.Shape()[dim]
+				start := row * rowSize
+				newData = append(newData, tData[start:start+rowSize]...)
+			}
+		}
+		return cpu.NewTensor(newShape, newData)
+	}
+	
+	// General case: just concatenate flat data (works for dim=0)
+	if dim == 0 {
+		newData := make([]float32, 0, newShape.NumElements())
+		for _, t := range tensors {
+			newData = append(newData, t.DataFloat32()...)
+		}
+		return cpu.NewTensor(newShape, newData)
+	}
+	
+	// General case: concat along middle dimension
+	// Calculate outer (before dim), inner (after dim) sizes
+	outerSize := 1
+	for i := 0; i < dim; i++ {
+		outerSize *= refShape[i]
+	}
+	innerSize := 1
+	for i := dim + 1; i < len(refShape); i++ {
+		innerSize *= refShape[i]
+	}
+	
+	newData := make([]float32, 0, newShape.NumElements())
+	
+	// For each outer index, copy all tensors' slices
+	for outer := 0; outer < outerSize; outer++ {
+		for _, t := range tensors {
+			tData := t.DataFloat32()
+			dimSize := t.Shape()[dim]
+			sliceSize := dimSize * innerSize
+			start := outer * sliceSize
+			newData = append(newData, tData[start:start+sliceSize]...)
+		}
+	}
+	
+	return cpu.NewTensor(newShape, newData)
+}
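
In contrast to the reshape helpers, Slice always copies into a fresh buffer, and Concat along dim 0 reduces to appending the flat data; a small sketch with an illustrative 2x4 tensor:

package main

import (
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/ops"
	"makarna/pkg/tensor"
)

func main() {
	x := cpu.NewTensor(tensor.Shape{2, 4}, []float32{
		0, 1, 2, 3,
		4, 5, 6, 7,
	})

	// Keep columns 1..2 (end is exclusive); the result owns its own buffer.
	cols := ops.Slice(x, 1, 1, 3)
	fmt.Println(cols.Shape(), cols.DataFloat32()) // [2 2] [1 2 5 6]

	// Stack two copies along the row dimension.
	both := ops.Concat([]*cpu.Tensor{x, x}, 0)
	fmt.Println(both.Shape()) // [4 4]
}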

+ 99 - 0
pkg/backend/cpu/simd.go

@@ -0,0 +1,99 @@
+package cpu
+
+import "unsafe"
+
+// DotFloat32 computes the dot product between two float32 slices.
+// Dispatch order: AVX-512 -> AVX2 -> scalar fallback.
+func DotFloat32(a, b []float32) float32 {
+	if len(a) != len(b) {
+		panic("DotFloat32: mismatched slice lengths")
+	}
+	if len(a) == 0 {
+		return 0
+	}
+
+	if hasAVX512Kernel && SupportsAVX512() && len(a) >= 16 {
+		return dotFloat32AVX512(a, b)
+	}
+	if hasAVX2Kernel && SupportsAVX2() && len(a) >= 8 {
+		return dotFloat32AVX2(a, b)
+	}
+
+	return dotFloat32Scalar(a, b)
+}
+
+func DotFloat32Ptr(a, b *float32, n int) float32 {
+	if n <= 0 {
+		return 0
+	}
+	if a == nil || b == nil {
+		panic("DotFloat32Ptr: nil pointer")
+	}
+
+	if hasAVX512Kernel && SupportsAVX512() && n >= 16 {
+		return dotAVX512(a, b, n)
+	}
+	if hasAVX2Kernel && SupportsAVX2() && n >= 8 {
+		return dotAVX2(a, b, n)
+	}
+
+	aa := unsafe.Slice(a, n)
+	bb := unsafe.Slice(b, n)
+	return dotFloat32Scalar(aa, bb)
+}
+
+func dotFloat32Scalar(a, b []float32) float32 {
+	var sum float32
+	for i := 0; i < len(a); i++ {
+		sum += a[i] * b[i]
+	}
+	return sum
+}
+
+// Axpy performs y += alpha * x for float32 slices of equal length.
+// Intended for small vector adds in attention/value accumulation.
+func Axpy(alpha float32, x, y []float32) {
+	if len(x) != len(y) {
+		panic("Axpy: mismatched slice lengths")
+	}
+	if len(x) == 0 {
+		return
+	}
+
+	if hasAVX512Kernel && SupportsAVX512() && len(x) >= 16 {
+		axpyFloat32AVX512(alpha, x, y)
+		return
+	}
+	if hasAVX2Kernel && SupportsAVX2() && len(x) >= 8 {
+		axpyFloat32AVX2(alpha, x, y)
+		return
+	}
+
+	for i := 0; i < len(x); i++ {
+		y[i] += alpha * x[i]
+	}
+}
+
+func AxpyPtr(alpha float32, x, y *float32, n int) {
+	if n <= 0 {
+		return
+	}
+	if x == nil || y == nil {
+		panic("AxpyPtr: nil pointer")
+	}
+
+	if hasAVX512Kernel && SupportsAVX512() && n >= 16 {
+		axpyAVX512(alpha, x, y, n)
+		return
+	}
+	if hasAVX2Kernel && SupportsAVX2() && n >= 8 {
+		axpyAVX2(alpha, x, y, n)
+		return
+	}
+
+	xs := unsafe.Slice(x, n)
+	ys := unsafe.Slice(y, n)
+	for i := 0; i < n; i++ {
+		ys[i] += alpha * xs[i]
+	}
+}

+ 28 - 0
pkg/backend/cpu/simd_avx2.go

@@ -0,0 +1,28 @@
+//go:build amd64
+
+package cpu
+
+// hasAVX2Kernel signals whether the AVX2 assembly implementation is available.
+const hasAVX2Kernel = true
+
+//go:noescape
+func dotAVX2(a *float32, b *float32, n int) float32
+
+//go:noescape
+func axpyAVX2(alpha float32, x *float32, y *float32, n int)
+
+func dotFloat32AVX2(a, b []float32) float32 {
+	if len(a) == 0 {
+		return 0
+	}
+	return dotAVX2(&a[0], &b[0], len(a))
+}
+
+func axpyFloat32AVX2(alpha float32, x, y []float32) {
+	if len(x) == 0 {
+		return
+	}
+	axpyAVX2(alpha, &x[0], &y[0], len(x))
+}
+
+

+ 111 - 0
pkg/backend/cpu/simd_avx2.s

@@ -0,0 +1,111 @@
+//go:build amd64
+// +build amd64
+
+#include "textflag.h"
+
+// func dotAVX2(a *float32, b *float32, n int) float32
+TEXT ·dotAVX2(SB), NOSPLIT, $0-24
+	MOVQ a+0(FP), DI
+	MOVQ b+8(FP), SI
+	MOVQ n+16(FP), CX
+
+	// Accumulator ymm0 = 0
+	VXORPS Y0, Y0, Y0
+
+	TESTQ CX, CX
+	JLE dot_zero
+
+// Process 8 floats per iteration
+loop8:
+	CMPQ CX, $8
+	JL fold
+	VMOVUPS (DI), Y1
+	VMOVUPS (SI), Y2
+	VFMADD231PS Y1, Y2, Y0 // Y0 += Y1 * Y2
+	ADDQ $32, DI
+	ADDQ $32, SI
+	SUBQ $8, CX
+	JMP loop8
+
+// Fold ymm0 upper half into xmm0 before handling tails.
+fold:
+	VEXTRACTF128 $1, Y0, X1
+	VADDPS X1, X0, X0
+
+// Scalar tail
+loop1:
+	CMPQ CX, $4
+	JL loop_scalar
+	VMOVUPS (DI), X1
+	VMOVUPS (SI), X2
+	VFMADD231PS X1, X2, X0
+	ADDQ $16, DI
+	ADDQ $16, SI
+	SUBQ $4, CX
+	JMP loop1
+
+loop_scalar:
+	TESTQ CX, CX
+	JE reduce4
+	MOVSS (DI), X1
+	MOVSS (SI), X2
+	VFMADD231SS X1, X2, X0
+	ADDQ $4, DI
+	ADDQ $4, SI
+	DECQ CX
+	JMP loop_scalar
+
+// Horizontal sum of xmm0 (4 lanes) to scalar
+reduce4:
+	VMOVHLPS X0, X0, X1
+	VADDPS X1, X0, X0
+	VPSHUFD $0xB1, X0, X1
+	VADDPS X1, X0, X0
+	MOVSS X0, ret+24(FP)
+	VZEROUPPER
+	RET
+
+dot_zero:
+	VXORPS X0, X0, X0
+	MOVSS X0, ret+24(FP)
+	RET
+
+
+// func axpyAVX2(alpha float32, x *float32, y *float32, n int)
+TEXT ·axpyAVX2(SB), NOSPLIT, $0-28
+	MOVSS alpha+0(FP), X0
+	VBROADCASTSS X0, Y0
+	MOVQ x+8(FP), DI
+	MOVQ y+16(FP), SI
+	MOVQ n+24(FP), CX
+
+	TESTQ CX, CX
+	JLE axpy_done
+
+axpy_loop8:
+	CMPQ CX, $8
+	JL axpy_loop1
+	VMOVUPS (DI), Y1
+	VMOVUPS (SI), Y2
+	VFMADD231PS Y0, Y1, Y2
+	VMOVUPS Y2, (SI)
+	ADDQ $32, DI
+	ADDQ $32, SI
+	SUBQ $8, CX
+	JMP axpy_loop8
+
+axpy_loop1:
+	TESTQ CX, CX
+	JE axpy_done
+	MOVSS (DI), X1
+	MOVSS (SI), X2
+	VFMADD231SS X0, X1, X2
+	MOVSS X2, (SI)
+	ADDQ $4, DI
+	ADDQ $4, SI
+	DECQ CX
+	JMP axpy_loop1
+
+axpy_done:
+	RET
+

+ 29 - 0
pkg/backend/cpu/simd_avx512.go

@@ -0,0 +1,29 @@
+//go:build amd64
+
+package cpu
+
+// hasAVX512Kernel signals whether the AVX-512 assembly implementation is
+// available in this build tag.
+const hasAVX512Kernel = true
+
+//go:noescape
+func dotAVX512(a *float32, b *float32, n int) float32
+
+//go:noescape
+func axpyAVX512(alpha float32, x *float32, y *float32, n int)
+
+func dotFloat32AVX512(a, b []float32) float32 {
+	if len(a) == 0 {
+		return 0
+	}
+	return dotAVX512(&a[0], &b[0], len(a))
+}
+
+func axpyFloat32AVX512(alpha float32, x, y []float32) {
+	if len(x) == 0 {
+		return
+	}
+	axpyAVX512(alpha, &x[0], &y[0], len(x))
+}
+
+

+ 114 - 0
pkg/backend/cpu/simd_avx512.s

@@ -0,0 +1,114 @@
+//go:build amd64
+// +build amd64
+
+#include "textflag.h"
+
+// func dotAVX512(a *float32, b *float32, n int) float32
+TEXT ·dotAVX512(SB), NOSPLIT, $0-24
+	// Load args
+	MOVQ a+0(FP), DI
+	MOVQ b+8(FP), SI
+	MOVQ n+16(FP), CX
+
+	// Accumulator
+	VXORPS Z0, Z0, Z0
+
+	// If n <= 0 return 0
+	TESTQ CX, CX
+	JLE dot512_zero
+
+// Process 16 floats per iteration
+loop16:
+	CMPQ CX, $16
+	JL fold512
+	VMOVUPS (DI), Z1
+	VMOVUPS (SI), Z2
+	VFMADD231PS Z1, Z2, Z0 // Z0 += Z1 * Z2
+	ADDQ $64, DI
+	ADDQ $64, SI
+	SUBQ $16, CX
+	JMP loop16
+
+// Fold zmm0 -> xmm0 before handling tails.
+fold512:
+	VEXTRACTF32X8 $1, Z0, Y1
+	VADDPS Y1, Y0, Y0
+	VEXTRACTF128 $1, Y0, X1
+	VADDPS X1, X0, X0
+
+// Scalar tail
+loop1:
+	CMPQ CX, $4
+	JL loop1_scalar
+	VMOVUPS (DI), X1
+	VMOVUPS (SI), X2
+	VFMADD231PS X1, X2, X0
+	ADDQ $16, DI
+	ADDQ $16, SI
+	SUBQ $4, CX
+	JMP loop1
+
+loop1_scalar:
+	TESTQ CX, CX
+	JE reduce4
+	MOVSS (DI), X1
+	MOVSS (SI), X2
+	VFMADD231SS X1, X2, X0
+	ADDQ $4, DI
+	ADDQ $4, SI
+	DECQ CX
+	JMP loop1_scalar
+
+// Horizontal sum of xmm0 (4 lanes) to scalar.
+reduce4:
+	VMOVHLPS X0, X0, X1
+	VADDPS X1, X0, X0
+	VPSHUFD $0xB1, X0, X1
+	VADDPS X1, X0, X0
+	MOVSS X0, ret+24(FP)
+	VZEROUPPER
+	RET
+
+dot512_zero:
+	VXORPS X0, X0, X0
+	MOVSS X0, ret+24(FP)
+	RET
+
+
+// func axpyAVX512(alpha float32, x *float32, y *float32, n int)
+TEXT ·axpyAVX512(SB), NOSPLIT, $0-28
+	MOVSS alpha+0(FP), X0
+	VBROADCASTSS X0, Z0
+	MOVQ x+8(FP), DI
+	MOVQ y+16(FP), SI
+	MOVQ n+24(FP), CX
+
+	TESTQ CX, CX
+	JLE axpy512_done
+
+axpy512_loop16:
+	CMPQ CX, $16
+	JL axpy512_loop1
+	VMOVUPS (DI), Z1
+	VMOVUPS (SI), Z2
+	VFMADD231PS Z0, Z1, Z2
+	VMOVUPS Z2, (SI)
+	ADDQ $64, DI
+	ADDQ $64, SI
+	SUBQ $16, CX
+	JMP axpy512_loop16
+
+axpy512_loop1:
+	TESTQ CX, CX
+	JE axpy512_done
+	MOVSS (DI), X1
+	MOVSS (SI), X2
+	VFMADD231SS X0, X1, X2
+	MOVSS X2, (SI)
+	ADDQ $4, DI
+	ADDQ $4, SI
+	DECQ CX
+	JMP axpy512_loop1
+
+axpy512_done:
+	RET

+ 12 - 0
pkg/backend/cpu/simd_noasm.go

@@ -0,0 +1,12 @@
+//go:build !amd64
+
+package cpu
+
+const hasAVX512Kernel = false
+const hasAVX2Kernel = false
+
+func dotFloat32AVX512(_, _ []float32) float32 { return 0 }
+func dotFloat32AVX2(_, _ []float32) float32   { return 0 }
+
+func axpyFloat32AVX512(_ float32, _, _ []float32) {}
+func axpyFloat32AVX2(_ float32, _, _ []float32)   {}

+ 30 - 0
pkg/backend/cpu/simd_random_test.go

@@ -0,0 +1,30 @@
+package cpu
+
+import (
+	"math"
+	"math/rand"
+	"testing"
+)
+
+func TestDotFloat32MatchesScalarRandomLengths(t *testing.T) {
+	rng := rand.New(rand.NewSource(2))
+	lengths := []int{1, 2, 3, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129}
+	for _, n := range lengths {
+		a := make([]float32, n)
+		b := make([]float32, n)
+		for i := 0; i < n; i++ {
+			a[i] = rng.Float32()*2 - 1
+			b[i] = rng.Float32()*2 - 1
+		}
+
+		got := DotFloat32(a, b)
+		var want float32
+		for i := 0; i < n; i++ {
+			want += a[i] * b[i]
+		}
+		if diff := math.Abs(float64(got - want)); diff > 1e-4 {
+			t.Fatalf("n=%d: got=%v want=%v diff=%g", n, got, want, diff)
+		}
+	}
+}
+

+ 138 - 0
pkg/backend/cpu/simd_test.go

@@ -0,0 +1,138 @@
+package cpu
+
+import (
+	"math/rand"
+	"testing"
+)
+
+func TestDotFloat32MatchesScalar(t *testing.T) {
+	a := []float32{1, 2, 3, 4, 5, 6, 7, 8}
+	b := []float32{8, 7, 6, 5, 4, 3, 2, 1}
+
+	got := DotFloat32(a, b)
+	want := dotFloat32Scalar(a, b)
+
+	if diff := absDiff(got, want); diff > 1e-5 {
+		t.Fatalf("dot mismatch: got %f want %f (diff %f)", got, want, diff)
+	}
+}
+
+func TestDotFloat32ZeroLen(t *testing.T) {
+	got := DotFloat32(nil, nil)
+	if got != 0 {
+		t.Fatalf("expected 0 for empty slices, got %f", got)
+	}
+}
+
+func BenchmarkDotFloat32(b *testing.B) {
+	size := 256
+	a := make([]float32, size)
+	bb := make([]float32, size)
+	for i := 0; i < size; i++ {
+		a[i] = rand.Float32()
+		bb[i] = rand.Float32()
+	}
+
+	b.ReportAllocs()
+	b.SetBytes(int64(size * 4 * 2))
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		_ = DotFloat32(a, bb)
+	}
+}
+
+func BenchmarkDotFloat32Scalar(b *testing.B) {
+	size := 256
+	a := make([]float32, size)
+	bb := make([]float32, size)
+	for i := 0; i < size; i++ {
+		a[i] = rand.Float32()
+		bb[i] = rand.Float32()
+	}
+
+	b.ReportAllocs()
+	b.SetBytes(int64(size * 4 * 2))
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		_ = dotFloat32Scalar(a, bb)
+	}
+}
+
+func absDiff(a, b float32) float32 {
+	if a > b {
+		return a - b
+	}
+	return b - a
+}
+
+func TestAxpy(t *testing.T) {
+	x := []float32{1, 2, 3}
+	y := []float32{4, 5, 6}
+	Axpy(2, x, y)
+	want := []float32{6, 9, 12}
+	for i := range y {
+		if diff := absDiff(y[i], want[i]); diff > 1e-6 {
+			t.Fatalf("axpy mismatch at %d: got %f want %f", i, y[i], want[i])
+		}
+	}
+}
+
+func BenchmarkDotFloat32Ptr(b *testing.B) {
+	size := 256
+	a := make([]float32, size)
+	bb := make([]float32, size)
+	for i := 0; i < size; i++ {
+		a[i] = rand.Float32()
+		bb[i] = rand.Float32()
+	}
+
+	b.ReportAllocs()
+	b.SetBytes(int64(size * 4 * 2))
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		_ = DotFloat32Ptr(&a[0], &bb[0], size)
+	}
+}
+
+func BenchmarkAxpy(b *testing.B) {
+	size := 256
+	x := make([]float32, size)
+	y := make([]float32, size)
+	for i := 0; i < size; i++ {
+		x[i] = rand.Float32()
+		y[i] = rand.Float32()
+	}
+	alpha := float32(0.7)
+
+	b.ReportAllocs()
+	b.SetBytes(int64(size * 4 * 2))
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		Axpy(alpha, x, y)
+	}
+}
+
+func BenchmarkAxpyPtr(b *testing.B) {
+	size := 256
+	x := make([]float32, size)
+	y := make([]float32, size)
+	for i := 0; i < size; i++ {
+		x[i] = rand.Float32()
+		y[i] = rand.Float32()
+	}
+	alpha := float32(0.7)
+
+	b.ReportAllocs()
+	b.SetBytes(int64(size * 4 * 2))
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		AxpyPtr(alpha, &x[0], &y[0], size)
+	}
+}
+
+

+ 277 - 0
pkg/backend/cpu/tensor.go

@@ -0,0 +1,277 @@
+package cpu
+
+import (
+	"fmt"
+	"unsafe"
+
+	"makarna/pkg/tensor"
+)
+
+// Tensor implements tensor.Tensor for CPU
+type Tensor struct {
+	shape       tensor.Shape
+	dtype       tensor.DType
+	dataFloat32 []float32
+	dataUint16  []uint16
+
+	dataQ4_K []tensor.BlockQ4_K
+	dataQ3_K []tensor.BlockQ3_K
+	dataQ5_K []tensor.BlockQ5_K
+	dataQ6_K []tensor.BlockQ6_K
+	dataQ8_K []tensor.BlockQ8_K
+	dataQ2_K []tensor.BlockQ2_K
+}
+
+func (t *Tensor) Placement() tensor.DevicePlacement {
+	return tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
+}
+
+// NewTensor creates a new Float32 tensor
+func NewTensor(shape tensor.Shape, data []float32) *Tensor {
+	if data == nil {
+		data = make([]float32, shape.NumElements())
+	}
+	return &Tensor{
+		shape:       shape,
+		dtype:       tensor.Float32,
+		dataFloat32: data,
+	}
+}
+
+func NewTensorU16(shape tensor.Shape, dtype tensor.DType, data []uint16) (*Tensor, error) {
+	if dtype != tensor.Float16 && dtype != tensor.BFloat16 {
+		return nil, fmt.Errorf("unsupported u16 tensor dtype: %v", dtype)
+	}
+	if data == nil {
+		data = make([]uint16, shape.NumElements())
+	}
+	if len(data) != shape.NumElements() {
+		return nil, fmt.Errorf("size mismatch for u16 tensor: expected %d, got %d", shape.NumElements(), len(data))
+	}
+	return &Tensor{shape: shape, dtype: dtype, dataUint16: data}, nil
+}
+
+// NewTensorFromBytes creates a tensor from raw bytes (zero-copy mmap)
+func NewTensorFromBytes(shape tensor.Shape, dtype tensor.DType, data []byte) (*Tensor, error) {
+	if len(data) == 0 {
+		return nil, fmt.Errorf("empty data for tensor")
+	}
+
+	ptr := unsafe.Pointer(&data[0])
+	t := &Tensor{
+		shape: shape,
+		dtype: dtype,
+	}
+
+	switch dtype {
+	case tensor.Float32:
+		expectedSize := shape.NumElements() * 4
+		if len(data) != expectedSize {
+			return nil, fmt.Errorf("size mismatch for F32 tensor: expected %d bytes, got %d", expectedSize, len(data))
+		}
+		// Zero-copy cast
+		t.dataFloat32 = unsafe.Slice((*float32)(ptr), shape.NumElements())
+
+	case tensor.Float16, tensor.BFloat16:
+		expectedSize := shape.NumElements() * 2
+		if len(data) != expectedSize {
+			return nil, fmt.Errorf("size mismatch for %v tensor: expected %d bytes, got %d", dtype, expectedSize, len(data))
+		}
+		t.dataUint16 = unsafe.Slice((*uint16)(ptr), shape.NumElements())
+
+	case tensor.Q4_K:
+		const blockSize = 256
+		const blockBytes = 144 // 2(D)+2(DMin)+12(scales)+128(qs)
+
+		elemCount := shape.NumElements()
+		if elemCount%blockSize != 0 {
+			return nil, fmt.Errorf("Q4_K tensor elements must be multiple of %d", blockSize)
+		}
+
+		numBlocks := elemCount / blockSize
+		expectedSize := numBlocks * blockBytes
+
+		if len(data) != expectedSize {
+			return nil, fmt.Errorf("size mismatch for Q4_K tensor: expected %d bytes, got %d", expectedSize, len(data))
+		}
+
+		t.dataQ4_K = unsafe.Slice((*tensor.BlockQ4_K)(ptr), numBlocks)
+
+	case tensor.Q8_K:
+		const blockSize = 256
+		const blockBytes = 292 // 4(D)+256(qs)+32(bsums)
+
+		elemCount := shape.NumElements()
+		if elemCount%blockSize != 0 {
+			return nil, fmt.Errorf("Q8_K tensor elements must be multiple of %d", blockSize)
+		}
+
+		numBlocks := elemCount / blockSize
+		expectedSize := numBlocks * blockBytes
+
+		if len(data) != expectedSize {
+			return nil, fmt.Errorf("size mismatch for Q8_K tensor: expected %d bytes, got %d", expectedSize, len(data))
+		}
+
+		t.dataQ8_K = unsafe.Slice((*tensor.BlockQ8_K)(ptr), numBlocks)
+
+	case tensor.Q3_K:
+		const blockSize = 256
+		const blockBytes = 110 // 32(hmask)+64(qs)+12(scales)+2(D)
+
+		elemCount := shape.NumElements()
+		if elemCount%blockSize != 0 {
+			return nil, fmt.Errorf("Q3_K tensor elements must be multiple of %d", blockSize)
+		}
+
+		numBlocks := elemCount / blockSize
+		expectedSize := numBlocks * blockBytes
+
+		if len(data) != expectedSize {
+			return nil, fmt.Errorf("size mismatch for Q3_K tensor: expected %d bytes, got %d", expectedSize, len(data))
+		}
+
+		t.dataQ3_K = unsafe.Slice((*tensor.BlockQ3_K)(ptr), numBlocks)
+
+	case tensor.Q5_K:
+		const blockSize = 256
+		const blockBytes = 176
+
+		elemCount := shape.NumElements()
+		if elemCount%blockSize != 0 {
+			return nil, fmt.Errorf("Q5_K tensor elements must be multiple of %d", blockSize)
+		}
+
+		numBlocks := elemCount / blockSize
+		expectedSize := numBlocks * blockBytes
+
+		if len(data) != expectedSize {
+			return nil, fmt.Errorf("size mismatch for Q5_K tensor: expected %d bytes, got %d", expectedSize, len(data))
+		}
+
+		t.dataQ5_K = unsafe.Slice((*tensor.BlockQ5_K)(ptr), numBlocks)
+
+	case tensor.Q6_K:
+		const blockSize = 256
+		const blockBytes = 210 // 128(ql)+64(qh)+16(scales)+2(D)
+
+		elemCount := shape.NumElements()
+		if elemCount%blockSize != 0 {
+			return nil, fmt.Errorf("Q6_K tensor elements must be multiple of %d", blockSize)
+		}
+
+		numBlocks := elemCount / blockSize
+		expectedSize := numBlocks * blockBytes
+
+		if len(data) != expectedSize {
+			return nil, fmt.Errorf("size mismatch for Q6_K tensor: expected %d bytes, got %d", expectedSize, len(data))
+		}
+
+		t.dataQ6_K = unsafe.Slice((*tensor.BlockQ6_K)(ptr), numBlocks)
+
+	case tensor.Q2_K:
+		const blockSize = 256
+		const blockBytes = 84 // 16(scales)+64(qs)+2(D)+2(DMin)
+
+		elemCount := shape.NumElements()
+		if elemCount%blockSize != 0 {
+			return nil, fmt.Errorf("Q2_K tensor elements must be multiple of %d", blockSize)
+		}
+
+		numBlocks := elemCount / blockSize
+		expectedSize := numBlocks * blockBytes
+
+		if len(data) != expectedSize {
+			return nil, fmt.Errorf("size mismatch for Q2_K tensor: expected %d bytes, got %d", expectedSize, len(data))
+		}
+
+		t.dataQ2_K = unsafe.Slice((*tensor.BlockQ2_K)(ptr), numBlocks)
+
+	default:
+		return nil, fmt.Errorf("unsupported tensor dtype: %v", dtype)
+	}
+
+	return t, nil
+}
+
+// Shape returns tensor dimensions
+func (t *Tensor) Shape() tensor.Shape {
+	return t.shape
+}
+
+// DType returns data type
+func (t *Tensor) DType() tensor.DType {
+	return t.dtype
+}
+
+// Device returns CPU
+func (t *Tensor) Device() tensor.DeviceType {
+	return tensor.CPU
+}
+
+// Data returns raw data pointer.
+// Use DataFloat32/DataQ4_K etc for type-safe access.
+func (t *Tensor) Data() interface{} {
+	switch t.dtype {
+	case tensor.Float32:
+		return unsafe.Pointer(&t.dataFloat32[0])
+	case tensor.Float16, tensor.BFloat16:
+		return unsafe.Pointer(&t.dataUint16[0])
+
+	case tensor.Q4_K:
+		return unsafe.Pointer(&t.dataQ4_K[0])
+
+	case tensor.Q3_K:
+		return unsafe.Pointer(&t.dataQ3_K[0])
+
+	case tensor.Q5_K:
+		return unsafe.Pointer(&t.dataQ5_K[0])
+
+	case tensor.Q6_K:
+		return unsafe.Pointer(&t.dataQ6_K[0])
+
+	case tensor.Q8_K:
+		return unsafe.Pointer(&t.dataQ8_K[0])
+	case tensor.Q2_K:
+		return unsafe.Pointer(&t.dataQ2_K[0])
+	default:
+		panic(fmt.Sprintf("internal error: unsupported dtype %v in Data()", t.dtype))
+	}
+}
+
+// DataFloat32 returns the underlying float32 slice directly
+func (t *Tensor) DataFloat32() []float32 {
+	return t.dataFloat32
+}
+
+func (t *Tensor) DataUint16() []uint16 {
+	return t.dataUint16
+}
+
+// DataQ4_K returns the underlying Q4_K block slice
+func (t *Tensor) DataQ4_K() []tensor.BlockQ4_K {
+	return t.dataQ4_K
+}
+
+// DataQ3_K returns the underlying Q3_K block slice
+func (t *Tensor) DataQ3_K() []tensor.BlockQ3_K {
+	return t.dataQ3_K
+}
+func (t *Tensor) DataQ5_K() []tensor.BlockQ5_K {
+	return t.dataQ5_K
+}
+
+// DataQ6_K returns the underlying Q6_K block slice
+func (t *Tensor) DataQ6_K() []tensor.BlockQ6_K {
+	return t.dataQ6_K
+}
+
+// DataQ8_K returns the underlying Q8_K block slice
+func (t *Tensor) DataQ8_K() []tensor.BlockQ8_K {
+	return t.dataQ8_K
+}
+
+// DataQ2_K returns the underlying Q2_K block slice
+func (t *Tensor) DataQ2_K() []tensor.BlockQ2_K {
+	return t.dataQ2_K
+}
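
The size checks in NewTensorFromBytes boil down to a bytes-per-block calculation; for Q4_K (144 bytes per 256-element block, as noted above) a hypothetical 4096x4096 weight works out to:

package main

import "fmt"

func main() {
	// Q4_K: 2 (D) + 2 (DMin) + 12 (scales) + 128 (4-bit quants) = 144 bytes / 256 weights.
	const blockSize, blockBytes = 256, 144

	elems := 4096 * 4096 // hypothetical square weight matrix
	numBlocks := elems / blockSize
	fmt.Println(numBlocks * blockBytes) // 9437184 bytes expected by NewTensorFromBytes
}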

+ 36 - 0
pkg/backend/cpu/threads.go

@@ -0,0 +1,36 @@
+package cpu
+
+import (
+	"math"
+	"runtime"
+)
+
+var maxThreads int
+
+// init defaults to 90% of available cores (at least 1).
+func init() {
+	SetMaxThreads(-1)
+}
+
+// SetMaxThreads updates the maximum CPU worker count used by CPU kernels.
+// Passing n <= 0 (or n > NumCPU) falls back to 90% of available cores, with a minimum of 1.
+// It also sets GOMAXPROCS to match to avoid oversubscription.
+func SetMaxThreads(n int) {
+	cores := runtime.NumCPU()
+	if cores < 1 {
+		cores = 1
+	}
+	if n <= 0 || n > cores {
+		n = int(math.Ceil(0.9 * float64(cores)))
+		if n < 1 {
+			n = 1
+		}
+	}
+	maxThreads = n
+	runtime.GOMAXPROCS(n)
+}
+
+// MaxThreads returns the configured worker count.
+func MaxThreads() int {
+	return maxThreads
+}

+ 141 - 0
pkg/backend/cuda/attention_bench_test.go

@@ -0,0 +1,141 @@
+//go:build cuda
+
+package cuda
+
+import (
+	"fmt"
+	"math"
+	"math/rand"
+	"testing"
+	"unsafe"
+
+	"makarna/pkg/tensor"
+)
+
+func BenchmarkPagedAttentionF16KV_Decode(b *testing.B) {
+	if !Available() {
+		b.Skip("cuda not available")
+	}
+
+	const (
+		gpu        = 0
+		numHeads   = 32
+		numKVHeads = 8
+		headDim    = 128
+		blockSize  = 16
+	)
+	const seqLen = 1
+
+	kvDim := numKVHeads * headDim
+	scale := float32(1.0 / math.Sqrt(float64(headDim)))
+
+	kvLens := []int{256, 1024, 2048, 4096}
+	for _, kvLen := range kvLens {
+		b.Run(fmt.Sprintf("kvLen=%d", kvLen), func(b *testing.B) {
+			numBlocks := (kvLen + blockSize - 1) / blockSize
+			if numBlocks <= 0 {
+				b.Fatal("invalid numBlocks")
+			}
+
+			q, err := NewTensor(tensor.Shape{seqLen, numHeads * headDim}, tensor.Float32, gpu)
+			if err != nil {
+				b.Skipf("cuda alloc failed: %v", err)
+			}
+			defer q.Free()
+
+			out, err := NewTensor(tensor.Shape{seqLen, numHeads * headDim}, tensor.Float32, gpu)
+			if err != nil {
+				b.Skipf("cuda alloc failed: %v", err)
+			}
+			defer out.Free()
+
+			// Deterministic Q.
+			r := rand.New(rand.NewSource(1))
+			hostQ := make([]float32, numHeads*headDim)
+			for i := range hostQ {
+				hostQ[i] = r.Float32()*2 - 1
+			}
+			if err := q.CopyFrom(hostQ); err != nil {
+				b.Fatalf("CopyFrom Q: %v", err)
+			}
+
+			// Allocate K/V blocks as F16 and zero-initialize (setup only).
+			zeroHalf := make([]uint16, blockSize*kvDim)
+			kPtrs := make([]uintptr, numBlocks)
+			vPtrs := make([]uintptr, numBlocks)
+			kBlocks := make([]*Tensor, numBlocks)
+			vBlocks := make([]*Tensor, numBlocks)
+			for i := 0; i < numBlocks; i++ {
+				kb, err := NewTensor(tensor.Shape{blockSize, kvDim}, tensor.Float16, gpu)
+				if err != nil {
+					b.Skipf("cuda alloc failed: %v", err)
+				}
+				vb, err := NewTensor(tensor.Shape{blockSize, kvDim}, tensor.Float16, gpu)
+				if err != nil {
+					kb.Free()
+					b.Skipf("cuda alloc failed: %v", err)
+				}
+				kBlocks[i] = kb
+				vBlocks[i] = vb
+				kPtrs[i] = uintptr(kb.Data().(unsafe.Pointer))
+				vPtrs[i] = uintptr(vb.Data().(unsafe.Pointer))
+
+				if err := MemcpyH2D(kb.Data().(unsafe.Pointer), unsafe.Pointer(&zeroHalf[0]), uintptr(len(zeroHalf)*2), gpu); err != nil {
+					b.Fatalf("zero K: %v", err)
+				}
+				if err := MemcpyH2D(vb.Data().(unsafe.Pointer), unsafe.Pointer(&zeroHalf[0]), uintptr(len(zeroHalf)*2), gpu); err != nil {
+					b.Fatalf("zero V: %v", err)
+				}
+			}
+			defer func() {
+				for i := range kBlocks {
+					if kBlocks[i] != nil {
+						kBlocks[i].Free()
+					}
+					if vBlocks[i] != nil {
+						vBlocks[i].Free()
+					}
+				}
+			}()
+
+			kDev, err := AllocAndCopyPtrTable(kPtrs, gpu)
+			if err != nil {
+				b.Fatalf("AllocAndCopyPtrTable K: %v", err)
+			}
+			defer FreeDevicePtr(kDev)
+
+			vDev, err := AllocAndCopyPtrTable(vPtrs, gpu)
+			if err != nil {
+				FreeDevicePtr(kDev)
+				b.Fatalf("AllocAndCopyPtrTable V: %v", err)
+			}
+			defer FreeDevicePtr(vDev)
+
+			startPos := kvLen - 1
+			b.ResetTimer()
+
+			for i := 0; i < b.N; i++ {
+				if err := PagedAttentionF32F16KV(
+					q.Data().(unsafe.Pointer),
+					kDev,
+					vDev,
+					out.Data().(unsafe.Pointer),
+					seqLen,
+					kvLen,
+					numHeads,
+					numKVHeads,
+					headDim,
+					blockSize,
+					scale,
+					startPos,
+					gpu,
+				); err != nil {
+					b.Fatalf("PagedAttentionF32F16KV: %v", err)
+				}
+			}
+
+			// Ensure all kernels are complete before timing finishes.
+			_ = Synchronize(gpu)
+		})
+	}
+}

+ 1566 - 0
pkg/backend/cuda/cuda.go

@@ -0,0 +1,1566 @@
+//go:build cuda
+
+package cuda
+
+/*
+#cgo CFLAGS: -I${SRCDIR}
+#cgo LDFLAGS: -L${SRCDIR}/../../../build/cuda -Wl,-Bstatic -lmakarna_cuda -Wl,-Bdynamic
+#cgo LDFLAGS: -L/usr/local/cuda/lib64 -lcudart -lstdc++ -lm
+#cgo LDFLAGS: -Wl,-rpath,${SRCDIR}/../../../build/cuda -Wl,-rpath,/usr/local/cuda/lib64
+#include "kernels.h"
+*/
+import "C"
+import (
+	"errors"
+	"fmt"
+	"runtime"
+	"time"
+	"unsafe"
+
+	"makarna/pkg/profile"
+	"makarna/pkg/tensor"
+)
+
+func syncIfProfiling(gpu int) error {
+	if !profile.Enabled() {
+		return nil
+	}
+	return Synchronize(gpu)
+}
+
+// Ensure Interface Compliance
+var _ tensor.Tensor = (*Tensor)(nil)
+
+// Storage holds the underlying GPU memory with reference counting.
+// Multiple Tensors can share the same Storage (e.g., views, reshapes).
+// Memory is freed only when all references are gone.
+type Storage struct {
+	ptr unsafe.Pointer
+	gpu int
+	// Note: We rely on Go's GC and SetFinalizer for ref counting.
+	// Each Tensor that shares this storage keeps a reference to it.
+	// When the last Tensor is GC'd, the Storage becomes unreachable,
+	// and its finalizer frees the GPU memory.
+}
+
+// newStorage creates a new Storage and sets up its finalizer
+func newStorage(ptr unsafe.Pointer, gpu int) *Storage {
+	s := &Storage{ptr: ptr, gpu: gpu}
+	runtime.SetFinalizer(s, func(st *Storage) {
+		_ = C.cuda_set_device(C.int(st.gpu))
+		C.cuda_free(st.ptr)
+	})
+	return s
+}
+
+type Tensor struct {
+	shape   tensor.Shape
+	dtype   tensor.DType
+	storage *Storage       // Shared storage with ref counting
+	ptr     unsafe.Pointer // Pointer into storage (may be offset for slices)
+	gpu     int
+	// ownsStorage indicates whether this Tensor is responsible for explicitly
+	// freeing the underlying CUDA allocation.
+	// Views/reshapes must not free shared storage because they may outlive the base
+	// tensor (e.g. scratch-buffer views).
+	ownsStorage bool
+}
+
+// NewTensor allocates memory on the GPU
+func NewTensor(shape tensor.Shape, dtype tensor.DType, gpu int) (*Tensor, error) {
+	if dtype != tensor.Float32 && dtype != tensor.Float16 && dtype != tensor.BFloat16 {
+		return nil, errors.New("unsupported dtype on CUDA")
+	}
+
+	if gpu < 0 {
+		gpu = 0
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return nil, errors.New("failed to set cuda device")
+	}
+
+	size := shape.NumElements() * dtype.Size()
+	ptr := C.cuda_malloc(C.size_t(size))
+	if ptr == nil {
+		return nil, errors.New("cuda malloc failed")
+	}
+
+	storage := newStorage(ptr, gpu)
+
+	t := &Tensor{
+		shape:       shape,
+		dtype:       dtype,
+		storage:     storage,
+		ptr:         ptr,
+		gpu:         gpu,
+		ownsStorage: true,
+	}
+
+	return t, nil
+}
+
+func (t *Tensor) Shape() tensor.Shape {
+	return t.shape
+}
+
+func (t *Tensor) DType() tensor.DType {
+	return t.dtype
+}
+
+func (t *Tensor) Device() tensor.DeviceType {
+	return tensor.CUDA
+}
+
+// GPU returns the device ordinal.
+func (t *Tensor) GPU() int {
+	return t.gpu
+}
+
+func (t *Tensor) Placement() tensor.DevicePlacement {
+	return tensor.DevicePlacement{Type: tensor.CUDA, GPU: t.gpu}
+}
+
+func (t *Tensor) Data() interface{} {
+	return t.ptr
+}
+
+// Free explicitly frees the GPU memory associated with the tensor.
+// Use this for temporary tensors to avoid OOM due to delayed GC.
+func (t *Tensor) Free() {
+	if t == nil {
+		return
+	}
+	// Only the allocating tensor should explicitly free the CUDA allocation.
+	// Views/reshapes share storage and must not free it.
+	if t.storage != nil && t.ownsStorage {
+		// Clear finalizer so it doesn't run later
+		runtime.SetFinalizer(t.storage, nil)
+		_ = C.cuda_set_device(C.int(t.gpu))
+		C.cuda_free(t.storage.ptr)
+	}
+	t.storage = nil
+	t.ptr = nil
+}
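+
+// Usage sketch (illustrative): release scratch tensors explicitly instead of
+// waiting for the Storage finalizer, and do not Free a base tensor while views
+// created from it are still in use.
+//
+//	scratch, err := NewTensor(tensor.Shape{rows, cols}, tensor.Float32, 0)
+//	if err != nil {
+//		return err
+//	}
+//	defer scratch.Free() // frees the GPU memory deterministically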
+
+func (t *Tensor) Add(other tensor.Tensor) error {
+	o, ok := other.(*Tensor)
+	if !ok {
+		return errors.New("other must be CUDA tensor")
+	}
+	if t.dtype != tensor.Float32 || o.dtype != tensor.Float32 {
+		return errors.New("Add only supports Float32")
+	}
+	if t.shape.NumElements() != o.shape.NumElements() {
+		return errors.New("shape mismatch")
+	}
+
+	// Calls in-place add: t += o
+	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_add_f32((*C.float)(t.ptr), (*C.float)(o.ptr), C.size_t(t.shape.NumElements()))
+	if ret != 0 {
+		return errors.New("cuda add failed")
+	}
+	return nil
+}
+
+func PagedAttentionBatch(Q, kBlocksFlatDev, vBlocksFlatDev, blockOffsetsDev, kvLensDev, queryPosDev, out unsafe.Pointer, numTokens, numHeads, numKVHeads, headDim, blockSize int, scale float32, maxKvLen int, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_paged_attention_batch_f32(
+		(*C.float)(Q),
+		(**C.float)(kBlocksFlatDev),
+		(**C.float)(vBlocksFlatDev),
+		(*C.int)(blockOffsetsDev),
+		(*C.int)(kvLensDev),
+		(*C.int)(queryPosDev),
+		(*C.float)(out),
+		C.int(numTokens),
+		C.int(numHeads), C.int(numKVHeads), C.int(headDim),
+		C.int(blockSize),
+		C.float(scale),
+		C.int(maxKvLen),
+	)
+	if ret != 0 {
+		return errors.New("cuda paged attention batch failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (t *Tensor) Mul(other tensor.Tensor) error {
+	o, ok := other.(*Tensor)
+	if !ok {
+		return errors.New("other must be CUDA tensor")
+	}
+	if t.dtype != tensor.Float32 || o.dtype != tensor.Float32 {
+		return errors.New("Mul only supports Float32")
+	}
+	if t.shape.NumElements() != o.shape.NumElements() {
+		return errors.New("shape mismatch")
+	}
+
+	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_mul_f32((*C.float)(t.ptr), (*C.float)(o.ptr), C.size_t(t.shape.NumElements()))
+	if ret != 0 {
+		return errors.New("cuda mul failed")
+	}
+	return nil
+}
+
+func (t *Tensor) MatMul(other tensor.Tensor, out tensor.Tensor) error {
+	B, ok := other.(*Tensor)
+	if !ok {
+		return errors.New("other must be CUDA tensor")
+	}
+	C_out, ok := out.(*Tensor)
+	if !ok {
+		return errors.New("out must be CUDA tensor")
+	}
+	if t.dtype != tensor.Float32 || B.dtype != tensor.Float32 || C_out.dtype != tensor.Float32 {
+		return errors.New("MatMul only supports Float32")
+	}
+
+	if len(t.shape) != 2 || len(B.shape) != 2 || len(C_out.shape) != 2 {
+		return errors.New("only 2D matmul")
+	}
+
+	M := t.shape[0]
+	K := t.shape[1]
+	// We use NT matmul (A @ B^T), so B is expected to be [N, K]
+	N := B.shape[0]
+	K2 := B.shape[1]
+
+	if K != K2 {
+		return fmt.Errorf("k dimension mismatch: A[%d,%d] vs B[%d,%d]", M, K, N, K2)
+	}
+
+	if C_out.shape[0] != M || C_out.shape[1] != N {
+		return fmt.Errorf("out shape mismatch: expected [%d,%d], got [%d,%d]", M, N, C_out.shape[0], C_out.shape[1])
+	}
+
+	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_matmul_f32_nt(
+		(*C.float)(t.ptr),
+		(*C.float)(B.ptr),
+		(*C.float)(C_out.ptr),
+		C.int(M), C.int(K), C.int(N),
+	)
+	if ret != 0 {
+		return errors.New("cuda matmul failed")
+	}
+
+	return nil
+}
+
+// Reshape creates a view (shared storage) with new shape.
+// The new tensor shares the same underlying Storage, so memory
+// is only freed when all tensors sharing this storage are GC'd.
+func (t *Tensor) Reshape(shape tensor.Shape) (tensor.Tensor, error) {
+	if shape.NumElements() != t.shape.NumElements() {
+		return nil, errors.New("num elements mismatch")
+	}
+
+	// Share the same storage - Go's GC handles ref counting for us
+	return &Tensor{
+		shape:       shape,
+		dtype:       t.dtype,
+		storage:     t.storage, // Shared reference
+		ptr:         t.ptr,
+		gpu:         t.gpu,
+		ownsStorage: false,
+	}, nil
+}
+
+// ViewAt returns a view into the tensor starting at the given byte offset.
+// The returned tensor shares storage and does not allocate.
+func (t *Tensor) ViewAt(shape tensor.Shape, offsetBytes uintptr) (*Tensor, error) {
+	if t == nil {
+		return nil, errors.New("nil tensor")
+	}
+	if offsetBytes%uintptr(t.dtype.Size()) != 0 {
+		return nil, fmt.Errorf("offset %d not aligned to dtype size %d", offsetBytes, t.dtype.Size())
+	}
+
+	newPtr := unsafe.Pointer(uintptr(t.ptr) + offsetBytes)
+	return &Tensor{
+		shape:       shape,
+		dtype:       t.dtype,
+		storage:     t.storage,
+		ptr:         newPtr,
+		gpu:         t.gpu,
+		ownsStorage: false,
+	}, nil
+}
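+
+// Illustrative sketch: slice row i out of a [rows, dim] buffer without copying.
+// The byte offset must be a multiple of the dtype size; the resulting view
+// shares storage and never frees it.
+//
+//	rowBytes := uintptr(dim) * uintptr(buf.DType().Size())
+//	row, err := buf.ViewAt(tensor.Shape{1, dim}, uintptr(i)*rowBytes)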
+
+func (t *Tensor) View(shape tensor.Shape) (tensor.Tensor, error) {
+	return t.Reshape(shape)
+}
+
+func (t *Tensor) ToDevice(device tensor.DeviceType) (tensor.Tensor, error) {
+	if device == tensor.CUDA {
+		return t, nil
+	}
+	// TODO: support CUDA -> CPU by allocating a CPU tensor and copying D2H.
+	// Importing "makarna/pkg/backend/cpu" would not create an import cycle
+	// (both packages depend only on "makarna/pkg/tensor"); it just is not wired up yet.
+	if device == tensor.CPU {
+		return nil, errors.New("ToDevice(CPU) not implemented yet; use CopyToHost instead")
+	}
+	return nil, errors.New("unknown device")
+}
+
+func (t *Tensor) CopyFrom(data interface{}) error {
+	if t.dtype != tensor.Float32 {
+		return errors.New("CopyFrom only supports Float32")
+	}
+	// Assuming data is []float32 on Host
+	src, ok := data.([]float32)
+	if !ok {
+		return errors.New("data must be []float32")
+	}
+	size := len(src) * 4
+	if size != t.shape.NumElements()*t.dtype.Size() {
+		return errors.New("size mismatch")
+	}
+
+	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+
+	start := time.Now()
+	ret := C.cuda_memcpy_h2d(t.ptr, unsafe.Pointer(&src[0]), C.size_t(size))
+	if ret != 0 {
+		runtime.KeepAlive(src)
+		runtime.KeepAlive(t)
+		return errors.New("cuda memcpy failed")
+	}
+	profile.RecordTransfer("CopyFrom/H2D", profile.EventH2D, int64(size), time.Since(start), t.gpu)
+	runtime.KeepAlive(src)
+	runtime.KeepAlive(t)
+	return nil
+}
+
+// CopyToHost copies the tensor's Float32 contents back into the host slice dst.
+func (t *Tensor) CopyToHost(dst []float32) error {
+	if t.dtype != tensor.Float32 {
+		return errors.New("CopyToHost only supports Float32")
+	}
+	size := len(dst) * 4
+	if size != t.shape.NumElements()*4 {
+		return errors.New("size mismatch")
+	}
+	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+
+	start := time.Now()
+	ret := C.cuda_memcpy_d2h(unsafe.Pointer(&dst[0]), t.ptr, C.size_t(size))
+	if ret != 0 {
+		runtime.KeepAlive(dst)
+		runtime.KeepAlive(t)
+		return errors.New("cuda memcpy d2h failed")
+	}
+	profile.RecordTransfer("CopyToHost/D2H", profile.EventD2H, int64(size), time.Since(start), t.gpu)
+	runtime.KeepAlive(dst)
+	runtime.KeepAlive(t)
+	return nil
+}
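+
+// Round-trip sketch (illustrative): upload host data, run kernels, read back.
+//
+//	t, _ := NewTensor(tensor.Shape{n}, tensor.Float32, gpu)
+//	_ = t.CopyFrom(hostIn)    // H2D
+//	// ... launch kernels on t.Data().(unsafe.Pointer) ...
+//	_ = Synchronize(gpu)      // wait for results before reading them back
+//	_ = t.CopyToHost(hostOut) // D2H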
+
+func (t *Tensor) CopyToInt32(dst []int32) error {
+	if t.dtype != tensor.Int32 {
+		return errors.New("CopyToInt32 only supports Int32")
+	}
+	size := len(dst) * 4
+	if size != t.shape.NumElements()*4 {
+		return errors.New("size mismatch")
+	}
+	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_memcpy_d2h(unsafe.Pointer(&dst[0]), t.ptr, C.size_t(size))
+	if ret != 0 {
+		return errors.New("cuda memcpy d2h failed")
+	}
+	runtime.KeepAlive(dst)
+	runtime.KeepAlive(t)
+	return nil
+}
+
+// CopyPartialFrom copies a portion of host data to the tensor at a given offset.
+// dstOffset: offset in float32 elements from the start of the tensor
+// src: source data to copy from host
+func (t *Tensor) CopyPartialFrom(dstOffset int, src []float32) error {
+	if t.dtype != tensor.Float32 {
+		return errors.New("CopyPartialFrom only supports Float32")
+	}
+	if dstOffset+len(src) > t.shape.NumElements() {
+		return errors.New("partial copy would exceed tensor bounds")
+	}
+	if len(src) == 0 {
+		return nil
+	}
+
+	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+
+	// Calculate destination pointer with offset
+	dstPtr := unsafe.Pointer(uintptr(t.ptr) + uintptr(dstOffset*4))
+	size := len(src) * 4
+
+	start := time.Now()
+	ret := C.cuda_memcpy_h2d(dstPtr, unsafe.Pointer(&src[0]), C.size_t(size))
+	if ret != 0 {
+		runtime.KeepAlive(src)
+		runtime.KeepAlive(t)
+		return errors.New("cuda memcpy partial failed")
+	}
+	profile.RecordTransfer("CopyPartialFrom/H2D", profile.EventH2D, int64(size), time.Since(start), t.gpu)
+	runtime.KeepAlive(src)
+	runtime.KeepAlive(t)
+	return nil
+}
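+
+// Note: dstOffset above counts float32 elements, not bytes. For example,
+// overwriting the second row of a [rows, dim] tensor is CopyPartialFrom(dim, row).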
+
+// CopyPartialFromDevice copies a portion from another CUDA tensor into this tensor.
+// Offsets and length are in float32 elements.
+func (t *Tensor) CopyPartialFromDevice(dstOffset int, src *Tensor, srcOffset int, length int) error {
+	if t.dtype != src.dtype {
+		return errors.New("dtype mismatch")
+	}
+	if dstOffset+length > t.shape.NumElements() {
+		return errors.New("dst offset/length exceed tensor bounds")
+	}
+	if srcOffset+length > src.shape.NumElements() {
+		return errors.New("src offset/length exceed tensor bounds")
+	}
+	if length == 0 {
+		return nil
+	}
+	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+
+	start := time.Now()
+	eltSize := t.dtype.Size()
+	dstPtr := unsafe.Pointer(uintptr(t.ptr) + uintptr(dstOffset*eltSize))
+	srcPtr := unsafe.Pointer(uintptr(src.ptr) + uintptr(srcOffset*eltSize))
+	size := C.size_t(length * eltSize)
+	if ret := C.cuda_memcpy_d2d(dstPtr, srcPtr, size); ret != 0 {
+		runtime.KeepAlive(src)
+		runtime.KeepAlive(t)
+		return errors.New("cuda memcpy d2d failed")
+	}
+	profile.RecordTransfer("CopyPartialFromDevice/D2D", profile.EventD2D, int64(length*eltSize), time.Since(start), t.gpu)
+	runtime.KeepAlive(src)
+	runtime.KeepAlive(t)
+	return nil
+}
+
+func CastF32ToF16(srcF32, dstF16 unsafe.Pointer, n int, gpu int) error {
+	if n <= 0 {
+		return nil
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	if ret := C.cuda_cast_f32_to_f16((*C.float)(srcF32), (*C.ushort)(dstF16), C.int(n)); ret != 0 {
+		return errors.New("cuda cast f32->f16 failed")
+	}
+	return nil
+}
+
+func PagedAttentionF32F16KV(Q, kBlocksDev, vBlocksDev, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize int, scale float32, startPos int, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_paged_attention_f32_f16kv(
+		(*C.float)(Q),
+		(**C.ushort)(kBlocksDev),
+		(**C.ushort)(vBlocksDev),
+		(*C.float)(out),
+		C.int(seqLen), C.int(kvLen),
+		C.int(numHeads), C.int(numKVHeads), C.int(headDim),
+		C.int(blockSize),
+		C.float(scale), C.int(startPos),
+	)
+	if ret != 0 {
+		return errors.New("cuda paged attention f16kv failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+func PagedAttentionBatchF32F16KV(Q, kBlocksFlatDev, vBlocksFlatDev, blockOffsetsDev, kvLensDev, queryPosDev, out unsafe.Pointer, numTokens, numHeads, numKVHeads, headDim, blockSize int, scale float32, maxKvLen int, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_paged_attention_batch_f32_f16kv(
+		(*C.float)(Q),
+		(**C.ushort)(kBlocksFlatDev),
+		(**C.ushort)(vBlocksFlatDev),
+		(*C.int)(blockOffsetsDev),
+		(*C.int)(kvLensDev),
+		(*C.int)(queryPosDev),
+		(*C.float)(out),
+		C.int(numTokens),
+		C.int(numHeads), C.int(numKVHeads), C.int(headDim),
+		C.int(blockSize),
+		C.float(scale),
+		C.int(maxKvLen),
+	)
+	if ret != 0 {
+		return errors.New("cuda paged attention batch f16kv failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+// PagedAttentionRoPEF32F16KV runs paged attention with fused RoPE inside the kernel.
+// Expects un-rotated Q and un-rotated K blocks; V blocks are unchanged.
+func PagedAttentionRoPEF32F16KV(Q, kBlocksDev, vBlocksDev, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize int, scale float32, startPos int, theta float32, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_paged_attention_rope_f32_f16kv(
+		(*C.float)(Q),
+		(**C.ushort)(kBlocksDev),
+		(**C.ushort)(vBlocksDev),
+		(*C.float)(out),
+		C.int(seqLen), C.int(kvLen),
+		C.int(numHeads), C.int(numKVHeads), C.int(headDim),
+		C.int(blockSize),
+		C.float(scale), C.int(startPos),
+		C.float(theta),
+	)
+	if ret != 0 {
+		return errors.New("cuda paged attention rope f16kv failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+// PagedAttentionBatchRoPEF32F16KV runs batched paged attention with fused RoPE inside the kernel.
+// Expects un-rotated Q and un-rotated K blocks; V blocks are unchanged.
+func PagedAttentionBatchRoPEF32F16KV(Q, kBlocksFlatDev, vBlocksFlatDev, blockOffsetsDev, kvLensDev, queryPosDev, out unsafe.Pointer, numTokens, numHeads, numKVHeads, headDim, blockSize int, scale float32, maxKvLen int, theta float32, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_paged_attention_rope_batch_f32_f16kv(
+		(*C.float)(Q),
+		(**C.ushort)(kBlocksFlatDev),
+		(**C.ushort)(vBlocksFlatDev),
+		(*C.int)(blockOffsetsDev),
+		(*C.int)(kvLensDev),
+		(*C.int)(queryPosDev),
+		(*C.float)(out),
+		C.int(numTokens),
+		C.int(numHeads), C.int(numKVHeads), C.int(headDim),
+		C.int(blockSize),
+		C.float(scale),
+		C.int(maxKvLen),
+		C.float(theta),
+	)
+	if ret != 0 {
+		return errors.New("cuda paged attention batch rope f16kv failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+// Available reports whether CUDA support is available (always true when built with the cuda tag).
+func Available() bool {
+	return true
+}
+
+// MemoryInfo returns (total, free) bytes for the current CUDA device.
+func MemoryInfo() (total uint64, free uint64, err error) {
+	var cFree, cTotal C.size_t
+	ret := C.cuda_mem_info(&cFree, &cTotal)
+	if ret != 0 {
+		return 0, 0, errors.New("cuda_mem_info failed")
+	}
+	return uint64(cTotal), uint64(cFree), nil
+}
+
+// MemoryInfoDevice returns (total, free) bytes for the given CUDA device.
+func MemoryInfoDevice(gpu int) (total uint64, free uint64, err error) {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return 0, 0, errors.New("failed to set cuda device")
+	}
+	var cFree, cTotal C.size_t
+	ret := C.cuda_mem_info(&cFree, &cTotal)
+	if ret != 0 {
+		return 0, 0, errors.New("cuda_mem_info failed")
+	}
+	return uint64(cTotal), uint64(cFree), nil
+}
+
+// DeviceCount returns the number of visible CUDA devices.
+func DeviceCount() (int, error) {
+	var cCount C.int
+	ret := C.cuda_device_count(&cCount)
+	if ret != 0 {
+		return 0, errors.New("cuda_device_count failed")
+	}
+	if cCount < 0 {
+		return 0, errors.New("cuda_device_count returned negative")
+	}
+	return int(cCount), nil
+}
+
+// Synchronize waits for all queued work on the given GPU.
+// Use when explicit host/device coordination is required.
+func Synchronize(gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	if ret := C.cuda_synchronize(); ret != 0 {
+		return errors.New("cuda synchronize failed")
+	}
+	return nil
+}
+
+// ============================================================
+// Neural Network Operations
+// ============================================================
+
+// RMSNorm applies RMS normalization in-place on GPU
+// x: [seqLen, dim] device pointer, w: [dim] device pointer
+func RMSNorm(x, w unsafe.Pointer, seqLen, dim int, eps float32, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_rmsnorm_f32((*C.float)(x), (*C.float)(w), C.int(seqLen), C.int(dim), C.float(eps))
+	if ret != 0 {
+		return errors.New("cuda rmsnorm failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+// RoPE applies rotary positional embeddings in-place
+// x: [seqLen, numHeads * headDim] device pointer
+// positions: [seqLen] device pointer (int32)
+func RoPE(x, positions unsafe.Pointer, seqLen, numHeads, headDim int, theta float32, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_rope_f32((*C.float)(x), (*C.int)(positions), C.int(seqLen), C.int(numHeads), C.int(headDim), C.float(theta))
+	if ret != 0 {
+		return errors.New("cuda rope failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
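+
+// Sketch (illustrative): positions must already live on the device as int32.
+//
+//	posDev, err := AllocAndCopyInt32(positions, gpu) // positions: []int32, one entry per token
+//	if err != nil {
+//		return err
+//	}
+//	defer FreeDevicePtr(posDev)
+//	err = RoPE(x, posDev, seqLen, numHeads, headDim, theta, gpu)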
+
+// RoPESingle runs RoPE for a single token at a specific position.
+func RoPESingle(x unsafe.Pointer, pos, numHeads, headDim int, theta float32, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_rope_f32_single((*C.float)(x), C.int(pos), C.int(numHeads), C.int(headDim), C.float(theta))
+	if ret != 0 {
+		return errors.New("cuda rope single failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+// Softmax applies softmax along last dimension in-place
+func Softmax(x unsafe.Pointer, rows, cols int, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_softmax_f32((*C.float)(x), C.int(rows), C.int(cols))
+	if ret != 0 {
+		return errors.New("cuda softmax failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+// SiLU applies SiLU activation in-place: x = x * sigmoid(x)
+func SiLU(x unsafe.Pointer, n int, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_silu_f32((*C.float)(x), C.size_t(n))
+	if ret != 0 {
+		return errors.New("cuda silu failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+// MulInplace performs element-wise a = a * b
+func MulInplace(a, b unsafe.Pointer, n int, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_mul_inplace_f32((*C.float)(a), (*C.float)(b), C.size_t(n))
+	if ret != 0 {
+		return errors.New("cuda mul inplace failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+// Copy copies GPU memory: dst = src
+func Copy(dst, src unsafe.Pointer, n int, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_copy_f32((*C.float)(dst), (*C.float)(src), C.size_t(n))
+	if ret != 0 {
+		return errors.New("cuda copy failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+func KDACausalShortConv1D(x, state, w unsafe.Pointer, tokens, projSize, kernel int, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_kda_causal_short_conv1d_f32(
+		(*C.float)(x),
+		(*C.float)(state),
+		(*C.float)(w),
+		C.int(tokens),
+		C.int(projSize),
+		C.int(kernel),
+	)
+	if ret != 0 {
+		return errors.New("cuda kda causal short conv1d failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+func L2NormHeads(q, k unsafe.Pointer, tokens, numHeads, headDim int, eps float32, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_l2norm_heads_f32((*C.float)(q), (*C.float)(k), C.int(tokens), C.int(numHeads), C.int(headDim), C.float(eps))
+	if ret != 0 {
+		return errors.New("cuda l2norm heads failed")
+	}
+	return nil
+}
+
+func KDAGate(g, aLog, dtBias, out unsafe.Pointer, tokens, numHeads, headDim int, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_kda_gate_f32((*C.float)(g), (*C.float)(aLog), (*C.float)(dtBias), (*C.float)(out), C.int(tokens), C.int(numHeads), C.int(headDim))
+	if ret != 0 {
+		return errors.New("cuda kda gate failed")
+	}
+	return nil
+}
+
+func KDARecurrent(q, k, v, g, beta, state unsafe.Pointer, tokens, numHeads, headDim int, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_kda_recurrent_f32((*C.float)(q), (*C.float)(k), (*C.float)(v), (*C.float)(g), (*C.float)(beta), (*C.float)(state), C.int(tokens), C.int(numHeads), C.int(headDim))
+	if ret != 0 {
+		return errors.New("cuda kda recurrent failed")
+	}
+	return nil
+}
+
+func RMSNormGated(out, g, weight unsafe.Pointer, n, headDim int, eps float32, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_rmsnorm_gated_f32((*C.float)(out), (*C.float)(g), (*C.float)(weight), C.int(n), C.int(headDim), C.float(eps))
+	if ret != 0 {
+		return errors.New("cuda rmsnorm gated failed")
+	}
+	return nil
+}
+
+func Sigmoid(x unsafe.Pointer, n int, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_sigmoid_f32((*C.float)(x), C.int(n))
+	if ret != 0 {
+		return errors.New("cuda sigmoid failed")
+	}
+	return nil
+}
+
+func SoftmaxRows(x unsafe.Pointer, rows, cols int, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_softmax_rows_f32((*C.float)(x), C.int(rows), C.int(cols))
+	if ret != 0 {
+		return errors.New("cuda softmax rows failed")
+	}
+	return nil
+}
+
+func TopKPerRow(scores unsafe.Pointer, indices unsafe.Pointer, values unsafe.Pointer, rows, cols, k int, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_topk_per_row_f32((*C.float)(scores), (*C.int)(indices), (*C.float)(values), C.int(rows), C.int(cols), C.int(k))
+	if ret != 0 {
+		return errors.New("cuda topk per row failed")
+	}
+	return nil
+}
+
+// Attention computes full causal attention on GPU
+// Q: [seqLen, numHeads * headDim]
+// K, V: [kvLen, numKVHeads * headDim]
+// out: [seqLen, numHeads * headDim]
+func Attention(Q, K, V, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim int, scale float32, startPos int, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_attention_f32(
+		(*C.float)(Q), (*C.float)(K), (*C.float)(V), (*C.float)(out),
+		C.int(seqLen), C.int(kvLen), C.int(numHeads), C.int(numKVHeads), C.int(headDim),
+		C.float(scale), C.int(startPos),
+	)
+	if ret != 0 {
+		return errors.New("cuda attention failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+func PagedAttention(Q, kBlocksDev, vBlocksDev, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize int, scale float32, startPos int, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_paged_attention_f32(
+		(*C.float)(Q),
+		(**C.float)(kBlocksDev),
+		(**C.float)(vBlocksDev),
+		(*C.float)(out),
+		C.int(seqLen), C.int(kvLen), C.int(numHeads), C.int(numKVHeads), C.int(headDim),
+		C.int(blockSize),
+		C.float(scale), C.int(startPos),
+	)
+	if ret != 0 {
+		return errors.New("cuda paged attention failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+// AttentionTimed runs attention and returns kernel time in milliseconds.
+// Intended for profiling/debugging only (it synchronizes internally).
+func AttentionTimed(Q, K, V, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim int, scale float32, startPos int, gpu int) (float32, error) {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return 0, errors.New("failed to set cuda device")
+	}
+	var ms C.float
+	ret := C.cuda_attention_f32_timed(
+		(*C.float)(Q), (*C.float)(K), (*C.float)(V), (*C.float)(out),
+		C.int(seqLen), C.int(kvLen), C.int(numHeads), C.int(numKVHeads), C.int(headDim),
+		C.float(scale), C.int(startPos), &ms,
+	)
+	if ret != 0 {
+		return 0, errors.New("cuda attention timed failed")
+	}
+	return float32(ms), nil
+}
+
+// AddInplace performs element-wise a = a + b
+func AddInplace(a, b unsafe.Pointer, n int, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_add_f32((*C.float)(a), (*C.float)(b), C.size_t(n))
+	if ret != 0 {
+		return errors.New("cuda add failed")
+	}
+	return nil
+}
+
+// ============================================================
+// Dequantization Operations
+// ============================================================
+
+// DequantQ8K dequantizes Q8_K blocks on GPU
+// blocks: device pointer to Q8_K data
+// out: device pointer to output float32 (numBlocks * 256 elements)
+func DequantQ8K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_dequant_q8k(blocks, (*C.float)(out), C.int(numBlocks))
+	if ret != 0 {
+		return errors.New("cuda dequant q8k failed")
+	}
+	return nil
+}
+
+func DequantQ4K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_dequant_q4k(blocks, (*C.float)(out), C.int(numBlocks))
+	if ret != 0 {
+		return errors.New("cuda dequant q4k failed")
+	}
+	return nil
+}
+
+func DequantQ5K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_dequant_q5k(blocks, (*C.float)(out), C.int(numBlocks))
+	if ret != 0 {
+		return errors.New("cuda dequant q5k failed")
+	}
+	return nil
+}
+
+// DequantQ6K dequantizes Q6_K blocks on GPU
+func DequantQ6K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_dequant_q6k(blocks, (*C.float)(out), C.int(numBlocks))
+	if ret != 0 {
+		return errors.New("cuda dequant q6k failed")
+	}
+	return nil
+}
+
+func DequantQ3K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_dequant_q3k(blocks, (*C.float)(out), C.int(numBlocks))
+	if ret != 0 {
+		return errors.New("cuda dequant q3k failed")
+	}
+	return nil
+}
+
+func DequantQ2K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_dequant_q2k(blocks, (*C.float)(out), C.int(numBlocks))
+	if ret != 0 {
+		return errors.New("cuda dequant q2k failed")
+	}
+	return nil
+}
+
+// MatMulQ8K performs C = A @ dequant(B) where B is Q8_K quantized
+func MatMulQ8K(A unsafe.Pointer, B unsafe.Pointer, Cptr unsafe.Pointer, M, K, N, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_matmul_f32_q8k((*C.float)(A), B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
+	if ret != 0 {
+		return errors.New("cuda matmul q8k failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+func MatMulQ5K(A unsafe.Pointer, B unsafe.Pointer, Cptr unsafe.Pointer, M, K, N, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_matmul_f32_q5k((*C.float)(A), B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
+	if ret != 0 {
+		return errors.New("cuda matmul q5k failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+func MatMulQ4K(A unsafe.Pointer, B unsafe.Pointer, Cptr unsafe.Pointer, M, K, N, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_matmul_f32_q4k((*C.float)(A), B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
+	if ret != 0 {
+		return errors.New("cuda matmul q4k failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+func MatMulQ2K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	if k%256 != 0 {
+		return fmt.Errorf("MatMulQ2K: K must be multiple of 256, got %d", k)
+	}
+	ret := C.cuda_matmul_f32_q2k((*C.float)(aPtr), bPtr, (*C.float)(cPtr), C.int(m), C.int(k), C.int(n))
+	if ret != 0 {
+		return errors.New("cuda matmul q2k failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+func MatMulQ3K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	if k%256 != 0 {
+		return fmt.Errorf("MatMulQ3K: K must be multiple of 256, got %d", k)
+	}
+	ret := C.cuda_matmul_f32_q3k((*C.float)(aPtr), bPtr, (*C.float)(cPtr), C.int(m), C.int(k), C.int(n))
+	if ret != 0 {
+		return errors.New("cuda matmul q3k failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+func MatMulQ6K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	if k%256 != 0 {
+		return fmt.Errorf("MatMulQ6K: K must be multiple of 256, got %d", k)
+	}
+	ret := C.cuda_matmul_f32_q6k((*C.float)(aPtr), bPtr, (*C.float)(cPtr), C.int(m), C.int(k), C.int(n))
+	if ret != 0 {
+		return errors.New("cuda matmul q6k failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+func MatMulF32(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_matmul_f32_nt(
+		(*C.float)(A),
+		(*C.float)(B),
+		(*C.float)(Cptr),
+		C.int(M), C.int(K), C.int(N),
+	)
+	if ret != 0 {
+		return errors.New("cuda matmul f32 failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+// MatMulF16 performs C = A @ B^T where A and B are float16 (stored as uint16),
+// and C is float32 output.
+func MatMulF16(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_matmul_f16_nt(
+		(*C.ushort)(A),
+		(*C.ushort)(B),
+		(*C.float)(Cptr),
+		C.int(M), C.int(K), C.int(N),
+	)
+	if ret != 0 {
+		return errors.New("cuda matmul f16 failed")
+	}
+	if err := syncIfProfiling(gpu); err != nil {
+		return err
+	}
+	return nil
+}
+
+// FP16-input MatMul variants: storing activations in FP16 halves their memory
+// traffic compared to FP32. A is FP16, B is quantized, C is FP32 output.
+
+func MatMulF16Q8K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_matmul_f16_q8k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
+	if ret != 0 {
+		return errors.New("cuda matmul f16 q8k failed")
+	}
+	return nil
+}
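+
+// Sketch (illustrative): drive the FP16-activation path against Q8_K weights.
+// Assumes actF32/actF16/outF32 are device pointers sized for M*K, M*K and M*N
+// elements, and wDev came from UploadQ8K.
+//
+//	if err := CastF32ToF16(actF32, actF16, M*K, gpu); err != nil {
+//		return err
+//	}
+//	if err := MatMulF16Q8K(actF16, wDev, outF32, M, K, N, gpu); err != nil {
+//		return err
+//	}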
+
+func MatMulF16Q4K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_matmul_f16_q4k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
+	if ret != 0 {
+		return errors.New("cuda matmul f16 q4k failed")
+	}
+	return nil
+}
+
+func MatMulF16Q5K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_matmul_f16_q5k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
+	if ret != 0 {
+		return errors.New("cuda matmul f16 q5k failed")
+	}
+	return nil
+}
+
+func MatMulF16Q2K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_matmul_f16_q2k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
+	if ret != 0 {
+		return errors.New("cuda matmul f16 q2k failed")
+	}
+	return nil
+}
+
+func MatMulF16Q3K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_matmul_f16_q3k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
+	if ret != 0 {
+		return errors.New("cuda matmul f16 q3k failed")
+	}
+	return nil
+}
+
+func MatMulF16Q6K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_matmul_f16_q6k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
+	if ret != 0 {
+		return errors.New("cuda matmul f16 q6k failed")
+	}
+	return nil
+}
+
+// UploadQ8K uploads Q8_K blocks from host to GPU
+func UploadQ8K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return nil, errors.New("failed to set cuda device")
+	}
+
+	size := len(hostData)
+	ptr := C.cuda_malloc(C.size_t(size))
+	if ptr == nil {
+		return nil, errors.New("cuda malloc failed for Q8K")
+	}
+
+	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size))
+	if ret != 0 {
+		C.cuda_free(ptr)
+		return nil, errors.New("cuda memcpy h2d failed for Q8K")
+	}
+
+	return ptr, nil
+}
+
+func AllocAndCopyPtrTable(ptrs []uintptr, gpu int) (unsafe.Pointer, error) {
+	if len(ptrs) == 0 {
+		return nil, errors.New("empty ptr table")
+	}
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return nil, errors.New("failed to set cuda device")
+	}
+
+	size := len(ptrs) * int(unsafe.Sizeof(uintptr(0)))
+	ptr := C.cuda_malloc(C.size_t(size))
+	if ptr == nil {
+		return nil, errors.New("cuda malloc failed for ptr table")
+	}
+
+	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&ptrs[0]), C.size_t(size))
+	if ret != 0 {
+		C.cuda_free(ptr)
+		return nil, errors.New("cuda memcpy h2d failed for ptr table")
+	}
+
+	return ptr, nil
+}
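+
+// Sketch (illustrative): build the device-side pointer table consumed by the
+// paged-attention kernels. kBlockPtrs holds the uintptr values of per-block
+// device allocations, one entry per KV block.
+//
+//	kDev, err := AllocAndCopyPtrTable(kBlockPtrs, gpu)
+//	if err != nil {
+//		return err
+//	}
+//	defer FreeDevicePtr(kDev)
+//	// kDev is then passed as kBlocksDev to PagedAttention / PagedAttentionF32F16KV.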
+
+func UploadQ5K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return nil, errors.New("failed to set cuda device")
+	}
+
+	size := len(hostData)
+	ptr := C.cuda_malloc(C.size_t(size))
+	if ptr == nil {
+		return nil, errors.New("cuda malloc failed for Q5K")
+	}
+
+	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size))
+	if ret != 0 {
+		C.cuda_free(ptr)
+		return nil, errors.New("cuda memcpy h2d failed for Q5K")
+	}
+
+	return ptr, nil
+}
+
+// UploadQ4K uploads Q4_K blocks from host to GPU
+func UploadQ4K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return nil, errors.New("failed to set cuda device")
+	}
+
+	size := len(hostData)
+	ptr := C.cuda_malloc(C.size_t(size))
+	if ptr == nil {
+		return nil, errors.New("cuda malloc failed for Q4K")
+	}
+
+	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size))
+	if ret != 0 {
+		C.cuda_free(ptr)
+		return nil, errors.New("cuda memcpy h2d failed for Q4K")
+	}
+
+	return ptr, nil
+}
+
+// UploadQ2K uploads Q2_K blocks from host to GPU
+func UploadQ2K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return nil, errors.New("failed to set cuda device")
+	}
+
+	size := len(hostData)
+	ptr := C.cuda_malloc(C.size_t(size))
+	if ptr == nil {
+		return nil, errors.New("cuda malloc failed for Q2K")
+	}
+
+	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size))
+	if ret != 0 {
+		C.cuda_free(ptr)
+		return nil, errors.New("cuda memcpy h2d failed for Q2K")
+	}
+
+	return ptr, nil
+}
+
+// UploadQ3K uploads Q3_K blocks from host to GPU
+func UploadQ3K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return nil, errors.New("failed to set cuda device")
+	}
+
+	size := len(hostData)
+	ptr := C.cuda_malloc(C.size_t(size))
+	if ptr == nil {
+		return nil, errors.New("cuda malloc failed for Q3K")
+	}
+
+	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size))
+	if ret != 0 {
+		C.cuda_free(ptr)
+		return nil, errors.New("cuda memcpy h2d failed for Q3K")
+	}
+
+	return ptr, nil
+}
+
+// UploadQ6K uploads Q6_K blocks from host to GPU
+func UploadQ6K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return nil, errors.New("failed to set cuda device")
+	}
+
+	size := len(hostData)
+	ptr := C.cuda_malloc(C.size_t(size))
+	if ptr == nil {
+		return nil, errors.New("cuda malloc failed for Q6K")
+	}
+
+	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size))
+	if ret != 0 {
+		C.cuda_free(ptr)
+		return nil, errors.New("cuda memcpy h2d failed for Q6K")
+	}
+
+	return ptr, nil
+}
+
+// MemcpyH2D copies data from host to device pointer.
+// dst: device pointer
+// src: host data (unsafe.Pointer to first element)
+// size: number of bytes
+// gpu: device ordinal (made active before the copy)
+func MemcpyH2D(dst, src unsafe.Pointer, size uintptr, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_memcpy_h2d(dst, src, C.size_t(size))
+	if ret != 0 {
+		return errors.New("cuda memcpy h2d failed")
+	}
+	return nil
+}
+
+// MemcpyD2H copies data from device pointer to host pointer.
+func MemcpyD2H(dst, src unsafe.Pointer, size uintptr, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_memcpy_d2h(dst, src, C.size_t(size))
+	if ret != 0 {
+		return errors.New("cuda memcpy d2h failed")
+	}
+	return nil
+}
+
+func MemcpyD2D(dst, src unsafe.Pointer, size uintptr, gpu int) error {
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return errors.New("failed to set cuda device")
+	}
+	ret := C.cuda_memcpy_d2d(dst, src, C.size_t(size))
+	if ret != 0 {
+		return errors.New("cuda memcpy d2d failed")
+	}
+	return nil
+}
+
+// TopKLogitsF32 computes per-block top-k on GPU (with repetition penalty applied)
+// and returns the concatenated candidate list on host (caller does final global top-k).
+func TopKLogitsF32(logits unsafe.Pointer, vocab int, repIDs []int32, repPenalty float32, k int, gpu int) ([]int32, []float32, int, error) {
+	if k <= 0 {
+		return nil, nil, 0, nil
+	}
+	if k > 64 {
+		return nil, nil, 0, fmt.Errorf("TopKLogitsF32: k too large: %d", k)
+	}
+	blocks := (vocab + 2048 - 1) / 2048
+	if blocks <= 0 {
+		blocks = 1
+	}
+	count := blocks * k
+
+	var repPtr unsafe.Pointer
+	if len(repIDs) > 0 {
+		p, err := AllocAndCopyInt32(repIDs, gpu)
+		if err != nil {
+			return nil, nil, 0, err
+		}
+		repPtr = p
+		defer FreeDevicePtr(repPtr)
+	}
+
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return nil, nil, 0, errors.New("failed to set cuda device")
+	}
+
+	// Device outputs (allocated on the now-active device)
+	outIDsPtr := C.cuda_malloc(C.size_t(count * 4))
+	if outIDsPtr == nil {
+		return nil, nil, 0, errors.New("TopKLogitsF32: cuda malloc failed for outIDs")
+	}
+	defer C.cuda_free(outIDsPtr)
+
+	outScoresPtr := C.cuda_malloc(C.size_t(count * 4))
+	if outScoresPtr == nil {
+		return nil, nil, 0, errors.New("TopKLogitsF32: cuda malloc failed for outScores")
+	}
+	defer C.cuda_free(outScoresPtr)
+
+	ret := C.cuda_topk_logits_f32(
+		(*C.float)(logits),
+		C.int(vocab),
+		(*C.int)(repPtr),
+		C.int(len(repIDs)),
+		C.float(repPenalty),
+		C.int(k),
+		(*C.int)(outIDsPtr),
+		(*C.float)(outScoresPtr),
+	)
+	if ret != 0 {
+		return nil, nil, 0, errors.New("cuda topk logits failed")
+	}
+
+	ids := make([]int32, count)
+	scores := make([]float32, count)
+	if err := MemcpyD2H(unsafe.Pointer(&ids[0]), unsafe.Pointer(outIDsPtr), uintptr(count*4), gpu); err != nil {
+		return nil, nil, 0, err
+	}
+	if err := MemcpyD2H(unsafe.Pointer(&scores[0]), unsafe.Pointer(outScoresPtr), uintptr(count*4), gpu); err != nil {
+		return nil, nil, 0, err
+	}
+
+	return ids, scores, blocks, nil
+}
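+
+// Sketch (illustrative): the blocks*k candidates returned above still need a
+// final global top-k on the host, e.g. by sorting the concatenated list.
+//
+//	type cand struct {
+//		id    int32
+//		score float32
+//	}
+//	cands := make([]cand, len(ids))
+//	for i := range ids {
+//		cands[i] = cand{ids[i], scores[i]}
+//	}
+//	sort.Slice(cands, func(i, j int) bool { return cands[i].score > cands[j].score })
+//	best := cands[:k]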
+
+// FreeDevicePtr frees a device pointer
+func FreeDevicePtr(ptr unsafe.Pointer) {
+	if ptr != nil {
+		C.cuda_free(ptr)
+	}
+}
+
+// Free is an alias for FreeDevicePtr for convenience
+func Free(ptr unsafe.Pointer) {
+	FreeDevicePtr(ptr)
+}
+
+// AllocAndCopyInt32 allocates GPU memory and copies int32 data to it
+// Returns raw device pointer (caller must Free it)
+func AllocAndCopyInt32(data []int32, gpu int) (unsafe.Pointer, error) {
+	if len(data) == 0 {
+		return nil, errors.New("empty data")
+	}
+
+	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
+		return nil, errors.New("failed to set cuda device")
+	}
+
+	size := len(data) * 4 // 4 bytes per int32
+	ptr := C.cuda_malloc(C.size_t(size))
+	if ptr == nil {
+		return nil, errors.New("cuda malloc failed for int32 data")
+	}
+
+	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&data[0]), C.size_t(size))
+	if ret != 0 {
+		C.cuda_free(ptr)
+		return nil, errors.New("cuda memcpy h2d failed for int32 data")
+	}
+
+	return ptr, nil
+}

+ 43 - 0
pkg/backend/cuda/cuda_common.cuh

@@ -0,0 +1,43 @@
+#ifndef MAKARNA_CUDA_COMMON_CUH
+#define MAKARNA_CUDA_COMMON_CUH
+
+#include "kernels.h"
+#define CUDA_API_PER_THREAD_DEFAULT_STREAM 1
+#include <cuda_runtime.h>
+#include <math.h>
+#include <stdio.h>
+
+#define CHECK_CUDA(call) \
+    do { \
+        cudaError_t err = call; \
+        if (err != cudaSuccess) { \
+            fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
+            return 1; \
+        } \
+    } while (0)
+
+// ============================================================
+// FP16 -> FP32 conversion (device function)
+// ============================================================
+__device__ __forceinline__ float fp16_to_fp32(unsigned short h) {
+    unsigned int sign = (h & 0x8000) << 16;
+    unsigned int exp = (h & 0x7C00) >> 10;
+    unsigned int mant = h & 0x03FF;
+    
+    if (exp > 0 && exp < 0x1F) {
+        // Normalized
+        unsigned int bits = sign | ((exp + 112) << 23) | (mant << 13);
+        return __int_as_float(bits);
+    }
+    if (exp == 0) {
+        if (mant == 0) return __int_as_float(sign); // Zero
+        // Denorm - rare case, simplified
+        float m = (float)mant / 1024.0f;
+        float val = m * 6.103515625e-05f; // 2^-14
+        return sign ? -val : val;
+    }
+    // Inf/NaN
+    return mant == 0 ? __int_as_float(sign | 0x7F800000) : __int_as_float(sign | 0x7FC00000);
+}
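+
+// Note on the constants above: FP16 has a 5-bit exponent with bias 15 and a
+// 10-bit mantissa, FP32 a bias of 127 and 23 mantissa bits. Re-biasing is
+// therefore exp + (127 - 15) = exp + 112 and the mantissa shifts left by
+// 23 - 10 = 13 bits; denormals are scaled by 2^-14, the smallest normal
+// FP16 exponent.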
+
+#endif

+ 129 - 0
pkg/backend/cuda/cuda_dequant_other.cu

@@ -0,0 +1,129 @@
+#include "cuda_common.cuh"
+
+// ============================================================
+// Q6_K Dequantization Kernel
+// 256 elements, lower 4 bits + upper 2 bits
+// ============================================================
+__global__ void dequant_q6k_kernel(const BlockQ6_K* blocks, float* out, int numBlocks) {
+    int blockIdx_q = blockIdx.x;
+    int elemIdx = threadIdx.x; // 0-255
+    
+    if (blockIdx_q >= numBlocks) return;
+    
+    const BlockQ6_K* b = &blocks[blockIdx_q];
+    float d = fp16_to_fp32(b->d);
+    
+    const int is = elemIdx / 32;
+    const int iq = elemIdx % 32;
+    
+    int qlIdx = (is / 4) * 64 + (is % 2) * 32 + iq;
+    int qhIdx = (is / 4) * 32 + iq;
+    int scIdx = (is / 4) * 8 + (is % 4) * 2 + (iq / 16);
+    
+    unsigned char ql = b->ql[qlIdx];
+    unsigned char qh = b->qh[qhIdx];
+    
+    int shift_ql = ((is % 4) < 2) ? 0 : 4;
+    int shift_qh = (is % 4) * 2;
+    
+    int q = ((ql >> shift_ql) & 0xF) | (((qh >> shift_qh) & 3) << 4);
+    q -= 32;
+    
+    out[blockIdx_q * 256 + elemIdx] = d * (float)b->scales[scIdx] * (float)q;
+}
+
+int cuda_dequant_q6k(const void* blocks, float* out, int numBlocks) {
+    if (numBlocks == 0) return 0;
+    dequant_q6k_kernel<<<numBlocks, 256>>>((const BlockQ6_K*)blocks, out, numBlocks);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// ============================================================
+// Q3_K Dequantization Kernel
+// ============================================================
+__device__ __forceinline__ signed char unpack_q3_scale(const unsigned char* packed, int idx) {
+    unsigned char sc;
+    if (idx < 8) {
+        sc = packed[idx] & 0xF;
+    } else {
+        sc = packed[idx - 8] >> 4;
+    }
+    sc |= ((packed[8 + (idx % 4)] >> (2 * (idx / 4))) & 0x3) << 4;
+    return (signed char)sc - 32;
+}
+
+__global__ void dequant_q3k_kernel(const BlockQ3_K* blocks, float* out, int numBlocks) {
+    int blockIdx_q = blockIdx.x;
+    int elemIdx = threadIdx.x;
+    
+    if (blockIdx_q >= numBlocks) return;
+    
+    const BlockQ3_K* b = &blocks[blockIdx_q];
+    float d = fp16_to_fp32(b->d);
+    
+    const int is = elemIdx / 32;
+    const int iq = elemIdx % 32;
+    
+    int qsIdx = (is / 4) * 32 + iq;
+    int hmaskIdx = iq;
+    
+    int scaleIdx = (is / 4) * 8 + (is % 4) * 2 + (iq / 16);
+    
+    int shift = (is % 4) * 2;
+    unsigned char m = 1 << ((is / 4) * 4 + (is % 4));
+    
+    signed char scale = unpack_q3_scale(b->scales, scaleIdx);
+    int qv = (b->qs[qsIdx] >> shift) & 0x3;
+    if ((b->hmask[hmaskIdx] & m) == 0) {
+        qv -= 4;
+    }
+    
+    out[blockIdx_q * 256 + elemIdx] = d * (float)scale * (float)qv;
+}
+
+int cuda_dequant_q3k(const void* blocks, float* out, int numBlocks) {
+    if (numBlocks == 0) return 0;
+    dequant_q3k_kernel<<<numBlocks, 256>>>((const BlockQ3_K*)blocks, out, numBlocks);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// ============================================================
+// Q2_K Dequantization Kernel
+// ============================================================
+__global__ void dequant_q2k_kernel(const BlockQ2_K* blocks, float* out, int numBlocks) {
+    int blockIdx_q = blockIdx.x;
+    int elemIdx = threadIdx.x;
+    
+    if (blockIdx_q >= numBlocks) return;
+    
+    const BlockQ2_K* b = &blocks[blockIdx_q];
+    float d = fp16_to_fp32(b->d);
+    float dmin = fp16_to_fp32(b->dmin);
+    
+    const int is = elemIdx / 32;
+    const int iq = elemIdx % 32;
+    
+    int scIdx = (is / 4) * 8 + (is % 4) * 2 + (iq / 16);
+    int qsIdx = (is / 4) * 32 + iq;
+    int shift = (is % 4) * 2;
+    
+    unsigned char sc = b->scales[scIdx];
+    float dl = d * (float)(sc & 0xF);
+    float ml = dmin * (float)(sc >> 4);
+    
+    int val = (b->qs[qsIdx] >> shift) & 3;
+    
+    out[blockIdx_q * 256 + elemIdx] = dl * (float)val - ml;
+}
+
+int cuda_dequant_q2k(const void* blocks, float* out, int numBlocks) {
+    if (numBlocks == 0) return 0;
+    dequant_q2k_kernel<<<numBlocks, 256>>>((const BlockQ2_K*)blocks, out, numBlocks);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}

+ 50 - 0
pkg/backend/cuda/cuda_dequant_q4k.cu

@@ -0,0 +1,50 @@
+#include "cuda_common.cuh"
+
+// ============================================================
+// Q4_K Dequantization Kernel
+// 256 elements per block, 128 bytes of qs (2 elements per byte)
+// Complex scale unpacking
+// ============================================================
+__global__ void dequant_q4k_kernel(const BlockQ4_K* blocks, float* out, int numBlocks) {
+    int blockIdx_q = blockIdx.x;
+    int elemIdx = threadIdx.x; // 0-255
+    
+    if (blockIdx_q >= numBlocks) return;
+    
+    const BlockQ4_K* b = &blocks[blockIdx_q];
+    float d = fp16_to_fp32(b->d);
+    float dmin = fp16_to_fp32(b->dmin);
+    
+    // Unpack scales and mins
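+    // The 12 scale bytes pack 8 six-bit scales and 8 six-bit mins: entries 0-3
+    // sit in the low 6 bits of bytes 0-3 (scales) and 4-7 (mins); entries 4-7
+    // use the low/high nibbles of bytes 8-11 plus the top 2 bits of bytes 0-7.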
+    unsigned char sc[8], m[8];
+    #pragma unroll
+    for (int j = 0; j < 4; j++) {
+        sc[j] = b->scales[j] & 63;
+        m[j] = b->scales[j + 4] & 63;
+    }
+    #pragma unroll
+    for (int j = 4; j < 8; j++) {
+        sc[j] = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
+        m[j] = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
+    }
+    
+    // Which sub-block and position
+    int subBlock = elemIdx / 32;
+    int subPos = elemIdx % 32;
+    int qsIdx = (subBlock / 2) * 32 + subPos;
+    
+    unsigned char qs = b->qs[qsIdx];
+    int val = (subBlock % 2 == 0) ? (qs & 0xF) : (qs >> 4);
+    
+    float scale = d * (float)sc[subBlock];
+    float minVal = dmin * (float)m[subBlock];
+    
+    out[blockIdx_q * 256 + elemIdx] = (float)val * scale - minVal;
+}
+
+int cuda_dequant_q4k(const void* blocks, float* out, int numBlocks) {
+    if (numBlocks == 0) return 0;
+    dequant_q4k_kernel<<<numBlocks, 256>>>((const BlockQ4_K*)blocks, out, numBlocks);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}

+ 54 - 0
pkg/backend/cuda/cuda_dequant_q5k.cu

@@ -0,0 +1,54 @@
+#include "cuda_common.cuh"
+
+__global__ void dequant_q5k_kernel(const BlockQ5_K* blocks, float* out, int numBlocks) {
+    int blockIdx_q = blockIdx.x;
+    int elemIdx = threadIdx.x;
+
+    if (blockIdx_q >= numBlocks) return;
+
+    const BlockQ5_K* b = &blocks[blockIdx_q];
+    float d = fp16_to_fp32(b->d);
+    float dmin = fp16_to_fp32(b->dmin);
+
+    unsigned char sc[8], m[8];
+    #pragma unroll
+    for (int j = 0; j < 4; j++) {
+        sc[j] = b->scales[j] & 63;
+        m[j] = b->scales[j + 4] & 63;
+    }
+    #pragma unroll
+    for (int j = 4; j < 8; j++) {
+        sc[j] = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
+        m[j] = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
+    }
+
+    int subBlock = elemIdx / 32;
+
+    int chunk = elemIdx / 64;
+    int posInChunk = elemIdx % 64;
+    int qsIdx = chunk * 32 + (posInChunk & 31);
+
+    unsigned char qs = b->qs[qsIdx];
+    int val;
+    unsigned char hb = b->qh[posInChunk & 31];
+    if (posInChunk < 32) {
+        val = (qs & 0xF);
+        val += ((hb >> (2 * chunk)) & 1) << 4;
+    } else {
+        val = (qs >> 4);
+        val += ((hb >> (2 * chunk + 1)) & 1) << 4;
+    }
+
+    float scale = d * (float)sc[subBlock];
+    float minVal = dmin * (float)m[subBlock];
+
+    out[blockIdx_q * 256 + elemIdx] = (float)val * scale - minVal;
+}
+
+int cuda_dequant_q5k(const void* blocks, float* out, int numBlocks) {
+    if (numBlocks == 0) return 0;
+    dequant_q5k_kernel<<<numBlocks, 256>>>((const BlockQ5_K*)blocks, out, numBlocks);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}

+ 24 - 0
pkg/backend/cuda/cuda_dequant_q8k.cu

@@ -0,0 +1,24 @@
+#include "cuda_common.cuh"
+
+// ============================================================
+// Q8_K Dequantization Kernel
+// Each thread handles 1 element within a block
+// ============================================================
+__global__ void dequant_q8k_kernel(const BlockQ8_K* blocks, float* out, int numBlocks) {
+    int blockIdx_q = blockIdx.x;
+    int elemIdx = threadIdx.x; // 0-255
+    
+    if (blockIdx_q >= numBlocks) return;
+    
+    const BlockQ8_K* b = &blocks[blockIdx_q];
+    float d = b->d;
+    
+    out[blockIdx_q * 256 + elemIdx] = d * (float)b->qs[elemIdx];
+}
+
+int cuda_dequant_q8k(const void* blocks, float* out, int numBlocks) {
+    if (numBlocks == 0) return 0;
+    dequant_q8k_kernel<<<numBlocks, 256>>>((const BlockQ8_K*)blocks, out, numBlocks);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}

+ 557 - 0
pkg/backend/cuda/cuda_elementwise.cu

@@ -0,0 +1,557 @@
+#include "cuda_common.cuh"
+
+// --- Kernels ---
+
+__global__ void add_kernel(float* a, const float* b, int n) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < n) {
+        a[idx] += b[idx];
+    }
+}
+
+int cuda_add_f32(float* a, float* b, size_t n) {
+    int threads = 256;
+    int blocks = (int)((n + threads - 1) / threads);
+    add_kernel<<<blocks, threads>>>(a, b, (int)n);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+__global__ void mul_kernel(float* a, float* b, int n) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < n) {
+        a[idx] *= b[idx];
+    }
+}
+
+int cuda_mul_f32(float* a, float* b, size_t n) {
+    int threads = 256;
+    int blocks = (int)((n + threads - 1) / threads);
+    mul_kernel<<<blocks, threads>>>(a, b, (int)n);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// SiLU kernel: x = x * sigmoid(x)
+__global__ void silu_kernel(float* x, int n) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < n) {
+        float val = x[idx];
+        x[idx] = val / (1.0f + __expf(-val));
+    }
+}
+
+int cuda_silu_f32(float* x, size_t n) {
+    int threads = 256;
+    int blocks = (int)((n + threads - 1) / threads);
+    silu_kernel<<<blocks, threads>>>(x, (int)n);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// Element-wise multiply in-place
+__global__ void mul_inplace_kernel(float* a, const float* b, int n) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < n) {
+        a[idx] *= b[idx];
+    }
+}
+
+int cuda_mul_inplace_f32(float* a, const float* b, size_t n) {
+    int threads = 256;
+    int blocks = (int)((n + threads - 1) / threads);
+    mul_inplace_kernel<<<blocks, threads>>>(a, b, (int)n);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// Copy kernel
+int cuda_copy_f32(float* dst, const float* src, size_t n) {
+    CHECK_CUDA(cudaMemcpy(dst, src, n * sizeof(float), cudaMemcpyDeviceToDevice));
+    return 0;
+}
+
+// ============================================================
+// KDA: Causal short conv1d + SiLU
+// ============================================================
+
+static __device__ __forceinline__ float sigmoid_f32(float x) {
+    return 1.0f / (1.0f + __expf(-x));
+}
+
+static __device__ __forceinline__ float silu_f32(float x) {
+    return x * sigmoid_f32(x);
+}
+
+// xTok: [projSize]
+// state: [projSize, convLen]
+// w: [projSize, kernel] (assumed contiguous)
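+//
+// For each channel d this is a depthwise causal convolution of width `kernel`
+// over the convLen cached past inputs plus the current input, followed by SiLU:
+//   y[d] = SiLU( sum_j w[d][j] * state[d][j] + w[d][kernel-1] * x[d] )
+// The current input is then shifted into state[d] for the next token.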
+__global__ void kda_causal_short_conv1d_token_kernel(
+    float* xTok,
+    float* state,
+    const float* w,
+    int projSize,
+    int kernel,
+    int convLen
+) {
+    int d = blockIdx.x * blockDim.x + threadIdx.x;
+    if (d >= projSize) {
+        return;
+    }
+
+    const int wBase = d * kernel;
+    const int stBase = d * convLen;
+
+    // Read input before overwriting xTok.
+    const float xIn = xTok[d];
+
+    float acc = 0.0f;
+    for (int j = 0; j < convLen; j++) {
+        acc = fmaf(w[wBase + j], state[stBase + j], acc);
+    }
+    acc = fmaf(w[wBase + convLen], xIn, acc);
+
+    xTok[d] = silu_f32(acc);
+
+    // Update causal state: shift left and append xIn.
+    if (convLen > 0) {
+        for (int j = 0; j < convLen - 1; j++) {
+            state[stBase + j] = state[stBase + j + 1];
+        }
+        state[stBase + convLen - 1] = xIn;
+    }
+}
+
+int cuda_kda_causal_short_conv1d_f32(
+    float* x,
+    float* state,
+    const float* w,
+    int tokens,
+    int projSize,
+    int kernel
+) {
+    if (tokens <= 0 || projSize <= 0) {
+        return 0;
+    }
+    if (kernel <= 1) {
+        // Just SiLU.
+        return cuda_silu_f32(x, (size_t)tokens * (size_t)projSize);
+    }
+
+    const int convLen = kernel - 1;
+    int threads = 256;
+    int blocks = (projSize + threads - 1) / threads;
+
+    for (int t = 0; t < tokens; t++) {
+        float* xTok = x + (size_t)t * (size_t)projSize;
+        kda_causal_short_conv1d_token_kernel<<<blocks, threads>>>(xTok, state, w, projSize, kernel, convLen);
+        CHECK_CUDA(cudaGetLastError());
+    }
+    return 0;
+}
+
+// ============================================================
+// KDA: L2 Norm Heads
+// ============================================================
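+//
+// Each [headDim] segment of q and k is rescaled to unit L2 norm:
+//   x <- x * rsqrt(sum(x^2) + eps)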
+
+__global__ void kda_l2norm_head_kernel(float* x, int headDim, float eps) {
+    // One block per head segment
+    extern __shared__ float sdata[];
+    
+    int tid = threadIdx.x;
+    float* head = x + blockIdx.x * headDim;
+    
+    // Compute sum of squares
+    float sum = 0.0f;
+    for (int i = tid; i < headDim; i += blockDim.x) {
+        float v = head[i];
+        sum += v * v;
+    }
+    sdata[tid] = sum;
+    __syncthreads();
+    
+    // Reduce
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sdata[tid] += sdata[tid + s];
+        }
+        __syncthreads();
+    }
+    
+    float invNorm = rsqrtf(sdata[0] + eps);
+    
+    // Normalize
+    for (int i = tid; i < headDim; i += blockDim.x) {
+        head[i] *= invNorm;
+    }
+}
+
+int cuda_l2norm_heads_f32(float* q, float* k, int tokens, int numHeads, int headDim, float eps) {
+    if (tokens <= 0 || numHeads <= 0 || headDim <= 0) return 0;
+    
+    int totalHeads = tokens * numHeads;
+    // The tree reduction in the kernel assumes blockDim.x is a power of two,
+    // so round the thread count down to one (capped at 256).
+    int threads = 256;
+    while (threads > headDim) threads >>= 1;
+    size_t sharedMem = threads * sizeof(float);
+    
+    kda_l2norm_head_kernel<<<totalHeads, threads, sharedMem>>>(q, headDim, eps);
+    CHECK_CUDA(cudaGetLastError());
+    kda_l2norm_head_kernel<<<totalHeads, threads, sharedMem>>>(k, headDim, eps);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// ============================================================
+// KDA: Gate computation
+// g_out = -exp(aLog[h]) * softplus(g + dtBias)
+// ============================================================
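+//
+// Illustrative values (assuming aLog[h] = 0, dtBias = 0): g = 0 gives
+// softplus(0) = ln 2 ~ 0.693 and g_out ~ -0.693; large positive g gives
+// g_out ~ -g, and very negative g drives g_out toward 0.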
+
+__device__ __forceinline__ float softplus_f32(float x) {
+    return (x > 20.0f) ? x : logf(1.0f + __expf(x));
+}
+
+__global__ void kda_gate_kernel(
+    const float* g,
+    const float* aLog,
+    const float* dtBias,
+    float* out,
+    int numHeads,
+    int headDim
+) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int projSize = numHeads * headDim;
+    if (idx >= projSize) return;
+    
+    int h = idx / headDim;
+    float mul = -__expf(aLog[h]);
+    float x = g[idx];
+    if (dtBias != nullptr) {
+        x += dtBias[idx];
+    }
+    out[idx] = mul * softplus_f32(x);
+}
+
+int cuda_kda_gate_f32(
+    const float* g,
+    const float* aLog,
+    const float* dtBias,
+    float* out,
+    int tokens,
+    int numHeads,
+    int headDim
+) {
+    if (tokens <= 0) return 0;
+    int projSize = numHeads * headDim;
+    int threads = 256;
+    int blocks = (projSize + threads - 1) / threads;
+    
+    for (int t = 0; t < tokens; t++) {
+        const float* gTok = g + t * projSize;
+        float* outTok = out + t * projSize;
+        kda_gate_kernel<<<blocks, threads>>>(gTok, aLog, dtBias, outTok, numHeads, headDim);
+        CHECK_CUDA(cudaGetLastError());
+    }
+    return 0;
+}
+
+// ============================================================
+// KDA: Recurrent (per-token, per-head)
+// state[h]: [headDim, headDim]
+// ============================================================
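+//
+// In matrix form, with S the per-head [headDim x headDim] state (rows indexed
+// by the key dimension, columns by the value dimension), each token applies a
+// gated delta-rule update:
+//   S     <- diag(exp(g)) * S
+//   S     <- S + beta * k * (v - S^T k)^T
+//   v_out <- S^T (q * scale)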
+
+__global__ void kda_recurrent_step_kernel(
+    const float* qTok,
+    const float* kTok,
+    float* vTok,
+    const float* gTok,
+    const float* betaTok,
+    float* state,
+    int numHeads,
+    int headDim,
+    float scale
+) {
+    // One block per head (blockIdx.x), threads work on headDim elements.
+    extern __shared__ float shared[];
+    float* tmpKV = shared;
+    float* tmpVM = shared + headDim;
+
+    int h = blockIdx.x;
+    if (h >= numHeads) return;
+
+    int tid = threadIdx.x;
+    int stateStride = headDim * headDim;
+    int off = h * headDim;
+
+    const float* q = qTok + off;
+    const float* k = kTok + off;
+    float* v = vTok + off;
+    const float* g = gTok + off;
+
+    float beta = betaTok[h];
+    float* st = state + h * stateStride;
+
+    // Step 1: Decay state by exp(g)
+    for (int kk = tid; kk < headDim; kk += blockDim.x) {
+        float dec = __expf(g[kk]);
+        for (int vv = 0; vv < headDim; vv++) {
+            st[kk * headDim + vv] *= dec;
+        }
+    }
+    __syncthreads();
+
+    // Step 2: tmpKV = k^T @ state (for each v dimension)
+    for (int vv = tid; vv < headDim; vv += blockDim.x) {
+        float acc = 0.0f;
+        for (int kk = 0; kk < headDim; kk++) {
+            acc += k[kk] * st[kk * headDim + vv];
+        }
+        tmpKV[vv] = acc;
+    }
+    __syncthreads();
+
+    // Step 3: tmpVM = v - tmpKV
+    for (int vv = tid; vv < headDim; vv += blockDim.x) {
+        tmpVM[vv] = v[vv] - tmpKV[vv];
+    }
+    __syncthreads();
+
+    // Step 4: state += beta * k @ tmpVM^T
+    for (int kk = tid; kk < headDim; kk += blockDim.x) {
+        float kj = beta * k[kk];
+        for (int vv = 0; vv < headDim; vv++) {
+            st[kk * headDim + vv] += kj * tmpVM[vv];
+        }
+    }
+    __syncthreads();
+
+    // Step 5: v = (q * scale)^T @ state
+    for (int vv = tid; vv < headDim; vv += blockDim.x) {
+        float acc = 0.0f;
+        for (int kk = 0; kk < headDim; kk++) {
+            acc += (q[kk] * scale) * st[kk * headDim + vv];
+        }
+        v[vv] = acc;
+    }
+}
+
+int cuda_kda_recurrent_f32(
+    const float* q,
+    const float* k,
+    float* v,
+    const float* g,
+    const float* beta,
+    float* state,
+    int tokens,
+    int numHeads,
+    int headDim
+) {
+    if (tokens <= 0 || numHeads <= 0 || headDim <= 0) return 0;
+    
+    int projSize = numHeads * headDim;
+    float scale = 1.0f / sqrtf((float)headDim);
+    
+    int threads = min(256, headDim);
+    size_t sharedMem = 2 * headDim * sizeof(float);
+    
+    for (int t = 0; t < tokens; t++) {
+        const float* qTok = q + t * projSize;
+        const float* kTok = k + t * projSize;
+        float* vTok = v + t * projSize;
+        const float* gTok = g + t * projSize;
+        const float* betaTok = beta + t * numHeads;
+
+        kda_recurrent_step_kernel<<<numHeads, threads, sharedMem>>>(
+            qTok, kTok, vTok, gTok, betaTok, state, numHeads, headDim, scale
+        );
+        CHECK_CUDA(cudaGetLastError());
+    }
+    return 0;
+}
+
+// ============================================================
+// KDA: RMSNorm Gated
+// out = (out / rms) * weight * sigmoid(g)
+// ============================================================
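+//
+// Here rms = sqrt(mean(out_h^2) + eps) is computed per headDim-sized segment,
+// and g (when non-null) gates the normalized output through a sigmoid.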
+
+__global__ void kda_rmsnorm_gated_kernel(
+    float* out,
+    const float* g,
+    const float* weight,
+    int headDim,
+    float eps
+) {
+    extern __shared__ float sdata[];
+    
+    int tid = threadIdx.x;
+    float* head = out + blockIdx.x * headDim;
+    const float* gHead = g ? (g + blockIdx.x * headDim) : nullptr;
+    
+    // Compute sum of squares
+    float sum = 0.0f;
+    for (int i = tid; i < headDim; i += blockDim.x) {
+        float v = head[i];
+        sum += v * v;
+    }
+    sdata[tid] = sum;
+    __syncthreads();
+    
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sdata[tid] += sdata[tid + s];
+        }
+        __syncthreads();
+    }
+    
+    float inv = rsqrtf(sdata[0] / (float)headDim + eps);
+    
+    for (int i = tid; i < headDim; i += blockDim.x) {
+        float y = head[i] * inv * weight[i];
+        if (gHead != nullptr) {
+            y *= 1.0f / (1.0f + __expf(-gHead[i]));  // sigmoid
+        }
+        head[i] = y;
+    }
+}
+
+int cuda_rmsnorm_gated_f32(
+    float* out,
+    const float* g,
+    const float* weight,
+    int n,
+    int headDim,
+    float eps
+) {
+    if (n <= 0 || headDim <= 0) return 0;
+    
+    int numHeads = n / headDim;
+    // The kernel's tree reduction assumes blockDim.x is a power of two.
+    int threads = 256;
+    while (threads > headDim) threads >>= 1;
+    size_t sharedMem = threads * sizeof(float);
+    
+    kda_rmsnorm_gated_kernel<<<numHeads, threads, sharedMem>>>(out, g, weight, headDim, eps);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// ============================================================
+// Sigmoid (for MoE router, etc.)
+// ============================================================
+
+__global__ void sigmoid_kernel(float* x, int n) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < n) {
+        x[idx] = 1.0f / (1.0f + __expf(-x[idx]));
+    }
+}
+
+int cuda_sigmoid_f32(float* x, int n) {
+    if (n <= 0) return 0;
+    int threads = 256;
+    int blocks = (n + threads - 1) / threads;
+    sigmoid_kernel<<<blocks, threads>>>(x, n);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// ============================================================
+// Softmax per row (for MoE router)
+// ============================================================
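+//
+// Uses the numerically stable form: each row is shifted by its max before exp.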
+
+__global__ void softmax_row_kernel(float* x, int cols) {
+    extern __shared__ float sdata[];
+    int row = blockIdx.x;
+    int tid = threadIdx.x;
+    float* rowData = x + row * cols;
+    
+    // Find max
+    float maxVal = -1e30f;
+    for (int i = tid; i < cols; i += blockDim.x) {
+        maxVal = fmaxf(maxVal, rowData[i]);
+    }
+    sdata[tid] = maxVal;
+    __syncthreads();
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (tid < s) sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
+        __syncthreads();
+    }
+    maxVal = sdata[0];
+    __syncthreads();
+    
+    // Compute exp and sum
+    float sum = 0.0f;
+    for (int i = tid; i < cols; i += blockDim.x) {
+        float v = __expf(rowData[i] - maxVal);
+        rowData[i] = v;
+        sum += v;
+    }
+    sdata[tid] = sum;
+    __syncthreads();
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (tid < s) sdata[tid] += sdata[tid + s];
+        __syncthreads();
+    }
+    float invSum = 1.0f / sdata[0];
+    
+    // Normalize
+    for (int i = tid; i < cols; i += blockDim.x) {
+        rowData[i] *= invSum;
+    }
+}
+
+int cuda_softmax_rows_f32(float* x, int rows, int cols) {
+    if (rows <= 0 || cols <= 0) return 0;
+    // The kernel's tree reductions assume blockDim.x is a power of two.
+    int threads = 256;
+    while (threads > cols) threads >>= 1;
+    size_t sharedMem = threads * sizeof(float);
+    softmax_row_kernel<<<rows, threads, sharedMem>>>(x, cols);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// ============================================================
+// TopK per row (for MoE expert selection)
+// ============================================================
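+//
+// One thread per row; ties between equal scores resolve to the lower column
+// index (strict > comparison in the selection loop).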
+
+__global__ void topk_per_row_kernel(
+    const float* scores,
+    int* indices,
+    float* values,
+    int cols,
+    int k
+) {
+    int row = blockIdx.x;
+    const float* rowScores = scores + row * cols;
+    int* rowIndices = indices + row * k;
+    float* rowValues = values + row * k;
+    
+    // Simple selection: k passes over the row plus an O(k) "already picked"
+    // check per candidate - good enough for the small k used in MoE routing
+    for (int i = 0; i < k; i++) {
+        float bestVal = -1e30f;
+        int bestIdx = -1;
+        for (int j = 0; j < cols; j++) {
+            float v = rowScores[j];
+            // Check if already selected
+            bool selected = false;
+            for (int p = 0; p < i; p++) {
+                if (rowIndices[p] == j) { selected = true; break; }
+            }
+            if (!selected && v > bestVal) {
+                bestVal = v;
+                bestIdx = j;
+            }
+        }
+        rowIndices[i] = bestIdx;
+        rowValues[i] = bestVal;
+    }
+}
+
+int cuda_topk_per_row_f32(
+    const float* scores,
+    int* indices,
+    float* values,
+    int rows,
+    int cols,
+    int k
+) {
+    if (rows <= 0 || cols <= 0 || k <= 0) return 0;
+    topk_per_row_kernel<<<rows, 1>>>(scores, indices, values, cols, k);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}

+ 254 - 0
pkg/backend/cuda/cuda_kernels_test.go

@@ -0,0 +1,254 @@
+//go:build cuda
+
+package cuda
+
+import (
+	"math"
+	"testing"
+	"unsafe"
+
+	"makarna/pkg/backend/cpu/nn"
+	"makarna/pkg/tensor"
+)
+
+func TestL2NormHeads(t *testing.T) {
+	if !Available() {
+		t.Skip("CUDA not available")
+	}
+
+	tokens, numHeads, headDim := 4, 8, 64
+	n := tokens * numHeads * headDim
+	eps := float32(1e-6)
+
+	// CPU reference
+	qCPU := make([]float32, n)
+	kCPU := make([]float32, n)
+	for i := range qCPU {
+		qCPU[i] = float32(i%100) / 50.0
+		kCPU[i] = float32((i+37)%100) / 50.0
+	}
+	qRef := make([]float32, n)
+	kRef := make([]float32, n)
+	copy(qRef, qCPU)
+	copy(kRef, kCPU)
+	nn.L2NormHeads(qRef, kRef, tokens, numHeads, headDim, eps)
+
+	// GPU
+	qDev, _ := NewTensor(tensor.Shape{tokens, numHeads * headDim}, tensor.Float32, 0)
+	defer qDev.Free()
+	kDev, _ := NewTensor(tensor.Shape{tokens, numHeads * headDim}, tensor.Float32, 0)
+	defer kDev.Free()
+	qDev.CopyFrom(qCPU)
+	kDev.CopyFrom(kCPU)
+
+	if err := L2NormHeads(qDev.Data().(unsafe.Pointer), kDev.Data().(unsafe.Pointer), tokens, numHeads, headDim, eps, 0); err != nil {
+		t.Fatal(err)
+	}
+
+	qOut := make([]float32, n)
+	kOut := make([]float32, n)
+	qDev.CopyToHost(qOut)
+	kDev.CopyToHost(kOut)
+
+	for i := 0; i < n; i++ {
+		if math.Abs(float64(qOut[i]-qRef[i])) > 1e-4 {
+			t.Errorf("Q mismatch at %d: got %f, want %f", i, qOut[i], qRef[i])
+			break
+		}
+		if math.Abs(float64(kOut[i]-kRef[i])) > 1e-4 {
+			t.Errorf("K mismatch at %d: got %f, want %f", i, kOut[i], kRef[i])
+			break
+		}
+	}
+}
+
+func TestSigmoid(t *testing.T) {
+	if !Available() {
+		t.Skip("CUDA not available")
+	}
+
+	n := 1024
+	input := make([]float32, n)
+	for i := range input {
+		input[i] = float32(i-512) / 100.0
+	}
+
+	// CPU reference
+	ref := make([]float32, n)
+	copy(ref, input)
+	nn.SigmoidInplace(ref)
+
+	// GPU
+	dev, _ := NewTensor(tensor.Shape{n}, tensor.Float32, 0)
+	defer dev.Free()
+	dev.CopyFrom(input)
+
+	if err := Sigmoid(dev.Data().(unsafe.Pointer), n, 0); err != nil {
+		t.Fatal(err)
+	}
+
+	out := make([]float32, n)
+	dev.CopyToHost(out)
+
+	for i := 0; i < n; i++ {
+		if math.Abs(float64(out[i]-ref[i])) > 1e-5 {
+			t.Errorf("Sigmoid mismatch at %d: got %f, want %f", i, out[i], ref[i])
+			break
+		}
+	}
+}
+
+func TestSoftmaxRows(t *testing.T) {
+	if !Available() {
+		t.Skip("CUDA not available")
+	}
+
+	rows, cols := 16, 64
+	n := rows * cols
+	input := make([]float32, n)
+	for i := range input {
+		input[i] = float32(i%100) / 50.0
+	}
+
+	// CPU reference (manual softmax per row)
+	ref := make([]float32, n)
+	copy(ref, input)
+	for r := 0; r < rows; r++ {
+		row := ref[r*cols : (r+1)*cols]
+		maxVal := row[0]
+		for _, v := range row {
+			if v > maxVal {
+				maxVal = v
+			}
+		}
+		sum := float32(0)
+		for i := range row {
+			row[i] = float32(math.Exp(float64(row[i] - maxVal)))
+			sum += row[i]
+		}
+		for i := range row {
+			row[i] /= sum
+		}
+	}
+
+	// GPU
+	dev, _ := NewTensor(tensor.Shape{rows, cols}, tensor.Float32, 0)
+	defer dev.Free()
+	dev.CopyFrom(input)
+
+	if err := SoftmaxRows(dev.Data().(unsafe.Pointer), rows, cols, 0); err != nil {
+		t.Fatal(err)
+	}
+
+	out := make([]float32, n)
+	dev.CopyToHost(out)
+
+	for i := 0; i < n; i++ {
+		if math.Abs(float64(out[i]-ref[i])) > 1e-5 {
+			t.Errorf("Softmax mismatch at %d: got %f, want %f", i, out[i], ref[i])
+			break
+		}
+	}
+}
+
+func TestTopKPerRow(t *testing.T) {
+	if !Available() {
+		t.Skip("CUDA not available")
+	}
+
+	rows, cols, k := 4, 16, 3
+	scores := make([]float32, rows*cols)
+	for i := range scores {
+		scores[i] = float32(i % cols)
+	}
+	// Set some specific values
+	scores[0*cols+5] = 100
+	scores[0*cols+10] = 90
+	scores[0*cols+2] = 80
+	scores[1*cols+15] = 50
+	scores[1*cols+0] = 40
+	scores[1*cols+7] = 30
+
+	// GPU
+	scoresDev, _ := NewTensor(tensor.Shape{rows, cols}, tensor.Float32, 0)
+	defer scoresDev.Free()
+	scoresDev.CopyFrom(scores)
+
+	indicesDev, _ := NewTensor(tensor.Shape{rows, k}, tensor.Int32, 0)
+	defer indicesDev.Free()
+	valuesDev, _ := NewTensor(tensor.Shape{rows, k}, tensor.Float32, 0)
+	defer valuesDev.Free()
+
+	if err := TopKPerRow(scoresDev.Data().(unsafe.Pointer), indicesDev.Data().(unsafe.Pointer), valuesDev.Data().(unsafe.Pointer), rows, cols, k, 0); err != nil {
+		t.Fatal(err)
+	}
+
+	indices := make([]int32, rows*k)
+	values := make([]float32, rows*k)
+	indicesDev.CopyToInt32(indices)
+	valuesDev.CopyToHost(values)
+
+	// Check first row: should be indices 5, 10, 2 with values 100, 90, 80
+	if indices[0] != 5 || indices[1] != 10 || indices[2] != 2 {
+		t.Errorf("Row 0 indices: got %v, want [5, 10, 2]", indices[0:3])
+	}
+	if values[0] != 100 || values[1] != 90 || values[2] != 80 {
+		t.Errorf("Row 0 values: got %v, want [100, 90, 80]", values[0:3])
+	}
+
+	// Check second row: should be indices 15, 0, 7 with values 50, 40, 30
+	if indices[3] != 15 || indices[4] != 0 || indices[5] != 7 {
+		t.Errorf("Row 1 indices: got %v, want [15, 0, 7]", indices[3:6])
+	}
+}
+
+func TestRMSNormGated(t *testing.T) {
+	if !Available() {
+		t.Skip("CUDA not available")
+	}
+
+	numHeads, headDim := 8, 64
+	n := numHeads * headDim
+	eps := float32(1e-5)
+
+	out := make([]float32, n)
+	g := make([]float32, n)
+	weight := make([]float32, headDim)
+	for i := range out {
+		out[i] = float32(i%100) / 50.0
+		g[i] = float32((i+13)%100) / 100.0
+	}
+	for i := range weight {
+		weight[i] = 1.0 + float32(i)/float32(headDim)
+	}
+
+	// CPU reference
+	ref := make([]float32, n)
+	copy(ref, out)
+	nn.RMSNormGated(ref, g, weight, headDim, eps)
+
+	// GPU
+	outDev, _ := NewTensor(tensor.Shape{n}, tensor.Float32, 0)
+	defer outDev.Free()
+	outDev.CopyFrom(out)
+	gDev, _ := NewTensor(tensor.Shape{n}, tensor.Float32, 0)
+	defer gDev.Free()
+	gDev.CopyFrom(g)
+	weightDev, _ := NewTensor(tensor.Shape{headDim}, tensor.Float32, 0)
+	defer weightDev.Free()
+	weightDev.CopyFrom(weight)
+
+	if err := RMSNormGated(outDev.Data().(unsafe.Pointer), gDev.Data().(unsafe.Pointer), weightDev.Data().(unsafe.Pointer), n, headDim, eps, 0); err != nil {
+		t.Fatal(err)
+	}
+
+	result := make([]float32, n)
+	outDev.CopyToHost(result)
+
+	for i := 0; i < n; i++ {
+		if math.Abs(float64(result[i]-ref[i])) > 1e-4 {
+			t.Errorf("RMSNormGated mismatch at %d: got %f, want %f", i, result[i], ref[i])
+			break
+		}
+	}
+}

+ 1295 - 0
pkg/backend/cuda/cuda_matmul.cu

@@ -0,0 +1,1295 @@
+#include "cuda_common.cuh"
+#include <cuda_fp16.h>
+
+namespace {
+
+// Simple tiled GEMM kernels for correctness-first dense matmul.
+// These are used as the default dense GEMM path when CUTLASS is not built.
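+// Each thread block computes one TILE x TILE tile of C, streaming the K
+// dimension through shared memory in TILE-wide slabs.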
+
+constexpr int TILE = 16;
+
+__global__ void matmul_f32_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C,
+                                  int M, int K, int N) {
+    __shared__ float As[TILE][TILE];
+    __shared__ float Bs[TILE][TILE];
+
+    const int row = blockIdx.y * TILE + threadIdx.y;
+    const int col = blockIdx.x * TILE + threadIdx.x;
+
+    float acc = 0.0f;
+    for (int t = 0; t < (K + TILE - 1) / TILE; ++t) {
+        const int aCol = t * TILE + threadIdx.x;
+        const int bRow = t * TILE + threadIdx.y;
+
+        As[threadIdx.y][threadIdx.x] = (row < M && aCol < K) ? A[row * K + aCol] : 0.0f;
+        Bs[threadIdx.y][threadIdx.x] = (bRow < K && col < N) ? B[bRow * N + col] : 0.0f;
+        __syncthreads();
+
+        #pragma unroll
+        for (int i = 0; i < TILE; ++i) {
+            acc += As[threadIdx.y][i] * Bs[i][threadIdx.x];
+        }
+        __syncthreads();
+    }
+
+    if (row < M && col < N) {
+        C[row * N + col] = acc;
+    }
+}
+
+// Computes C = A @ B^T where B is stored row-major [N, K].
+__global__ void matmul_f32_nt_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C,
+                                     int M, int K, int N) {
+    __shared__ float As[TILE][TILE];
+    __shared__ float Bs[TILE][TILE];
+
+    const int row = blockIdx.y * TILE + threadIdx.y;
+    const int col = blockIdx.x * TILE + threadIdx.x; // maps to n
+
+    float acc = 0.0f;
+    for (int t = 0; t < (K + TILE - 1) / TILE; ++t) {
+        const int aCol = t * TILE + threadIdx.x;
+        const int bCol = t * TILE + threadIdx.y; // k index for B row
+
+        As[threadIdx.y][threadIdx.x] = (row < M && aCol < K) ? A[row * K + aCol] : 0.0f;
+        Bs[threadIdx.y][threadIdx.x] = (col < N && bCol < K) ? B[col * K + bCol] : 0.0f;
+        __syncthreads();
+
+        #pragma unroll
+        for (int i = 0; i < TILE; ++i) {
+            acc += As[threadIdx.y][i] * Bs[i][threadIdx.x];
+        }
+        __syncthreads();
+    }
+
+    if (row < M && col < N) {
+        C[row * N + col] = acc;
+    }
+}
+
+// Computes C = A @ B^T where A and B are stored as IEEE half in uint16.
+__global__ void matmul_f16_nt_kernel(const __half* __restrict__ A, const __half* __restrict__ B, float* __restrict__ C,
+                                     int M, int K, int N) {
+    __shared__ __half As[TILE][TILE];
+    __shared__ __half Bs[TILE][TILE];
+
+    const int row = blockIdx.y * TILE + threadIdx.y;
+    const int col = blockIdx.x * TILE + threadIdx.x;
+
+    float acc = 0.0f;
+    for (int t = 0; t < (K + TILE - 1) / TILE; ++t) {
+        const int aCol = t * TILE + threadIdx.x;
+        const int bCol = t * TILE + threadIdx.y;
+
+        As[threadIdx.y][threadIdx.x] = (row < M && aCol < K) ? A[row * K + aCol] : __float2half(0.0f);
+        Bs[threadIdx.y][threadIdx.x] = (col < N && bCol < K) ? B[col * K + bCol] : __float2half(0.0f);
+        __syncthreads();
+
+        #pragma unroll
+        for (int i = 0; i < TILE; ++i) {
+            acc += __half2float(As[threadIdx.y][i]) * __half2float(Bs[i][threadIdx.x]);
+        }
+        __syncthreads();
+    }
+
+    if (row < M && col < N) {
+        C[row * N + col] = acc;
+    }
+}
+
+} // namespace
+
+__global__ void matmul_q5k_kernel(float* A, const BlockQ5_K* B, float* C,
+                                  int M, int K, int N, int blocksPerRow) {
+    const int row = blockIdx.y;
+    const int warp = threadIdx.y;
+    const int lane = threadIdx.x;
+    const int col = blockIdx.x * 8 + warp;
+
+    // row is uniform across the block, so an early return here is safe.
+    if (row >= M) return;
+    // col is warp-specific. Do NOT early-return on col>=N because we use __syncthreads().
+    const bool colIn = (col < N);
+
+    float sum = 0.0f;
+
+    // Cache the A tile (256 floats) once per block so the 8 warps (8 columns) reuse it.
+    __shared__ float a_sh[256];
+
+    __shared__ unsigned char sc_sh[8][8];
+    __shared__ unsigned char m_sh[8][8];
+    __shared__ float ds_sh[8][8];
+    __shared__ float dm_sh[8][8];
+    __shared__ float d_sh[8];
+    __shared__ float dmin_sh[8];
+
+    for (int blk = 0; blk < blocksPerRow; blk++) {
+        // Cache A tile once per block (256 floats). Each thread loads one element.
+        const int tid = warp * 32 + lane;
+        const float* aRow = A + row * K + blk * 256;
+        a_sh[tid] = aRow[tid];
+
+        if (colIn) {
+            const BlockQ5_K* b = &B[col * blocksPerRow + blk];
+            if (lane < 8) {
+                unsigned char sc;
+                unsigned char mn;
+                if (lane < 4) {
+                    sc = b->scales[lane] & 63;
+                    mn = b->scales[lane + 4] & 63;
+                } else {
+                    const int j = lane;
+                    sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
+                    mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
+                }
+                sc_sh[warp][lane] = sc;
+                m_sh[warp][lane] = mn;
+            }
+            if (lane == 0) {
+                d_sh[warp] = fp16_to_fp32(b->d);
+                dmin_sh[warp] = fp16_to_fp32(b->dmin);
+            }
+        }
+
+        // Ensure all warps have finished loading this block's a_sh before use,
+        // and that no warp overwrites a_sh while others are still reading it.
+        __syncthreads();
+
+        if (colIn) {
+            const BlockQ5_K* b = &B[col * blocksPerRow + blk];
+
+            // Precompute per-group multipliers once (one lane per group).
+            if (lane < 8) {
+                const float d = d_sh[warp];
+                const float dmin = dmin_sh[warp];
+                const unsigned char sc = sc_sh[warp][lane];
+                const unsigned char mn = m_sh[warp][lane];
+                ds_sh[warp][lane] = d * (float)sc;
+                dm_sh[warp][lane] = dmin * (float)mn;
+            }
+            __syncwarp();
+
+            const unsigned char hb = b->qh[lane];
+
+            #pragma unroll
+            for (int p = 0; p < 4; p++) {
+                const unsigned char qs = b->qs[p * 32 + lane];
+                int q0 = qs & 0xF;
+                int q1 = qs >> 4;
+                q0 += ((hb >> (2 * p)) & 1) << 4;
+                q1 += ((hb >> (2 * p + 1)) & 1) << 4;
+
+                const int idx0 = p * 64 + lane;
+                const int idx1 = idx0 + 32;
+
+                const int g0 = 2 * p;
+                const int g1 = g0 + 1;
+                const float ds0 = ds_sh[warp][g0];
+                const float dm0 = dm_sh[warp][g0];
+                const float ds1 = ds_sh[warp][g1];
+                const float dm1 = dm_sh[warp][g1];
+
+                sum += a_sh[idx0] * ((float)q0 * ds0 - dm0);
+                sum += a_sh[idx1] * ((float)q1 * ds1 - dm1);
+            }
+        }
+
+        __syncthreads();
+    }
+
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        sum += __shfl_down_sync(0xffffffff, sum, offset);
+    }
+    if (colIn && lane == 0) {
+        C[row * N + col] = sum;
+    }
+}
+
+int cuda_matmul_f32_q5k(float* A, const void* B, float* C, int M, int K, int N) {
+    int blocksPerRow = K / 256;
+    dim3 threads(32, 8);
+    dim3 blocks((N + 7) / 8, M);
+
+    matmul_q5k_kernel<<<blocks, threads>>>(A, (const BlockQ5_K*)B, C, M, K, N, blocksPerRow);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+
+int cuda_matmul_f32(float* A, float* B, float* C, int M, int K, int N) {
+    if (M <= 0 || N <= 0 || K <= 0) return 0;
+    dim3 threads(TILE, TILE);
+    dim3 blocks((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
+    matmul_f32_kernel<<<blocks, threads>>>(A, B, C, M, K, N);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+int cuda_matmul_f32_nt(float* A, float* B, float* C, int M, int K, int N) {
+    if (M <= 0 || N <= 0 || K <= 0) return 0;
+    dim3 threads(TILE, TILE);
+    dim3 blocks((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
+    matmul_f32_nt_kernel<<<blocks, threads>>>(A, B, C, M, K, N);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+int cuda_matmul_f16_nt(const unsigned short* A, const unsigned short* B, float* C, int M, int K, int N) {
+    if (M <= 0 || N <= 0 || K <= 0) return 0;
+    dim3 threads(TILE, TILE);
+    dim3 blocks((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
+    matmul_f16_nt_kernel<<<blocks, threads>>>(reinterpret_cast<const __half*>(A), reinterpret_cast<const __half*>(B), C, M, K, N);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// ============================================================
+// Fused Q8_K MatMul Kernel (tiled)
+// C[m,n] = sum_k A[m,k] * dequant(B[n,k])
+// Uses shared memory tiles to reduce global memory pressure.
+// ============================================================
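+//
+// Launch layout (see cuda_matmul_f32_q8k below): blockDim = (32, 8), i.e. 8
+// warps per block. Warp w handles output column blockIdx.x*8 + w of row
+// blockIdx.y; its 32 lanes split each 256-wide quant block and combine partial
+// sums with a warp-shuffle reduction.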
+__global__ void matmul_q8k_kernel(float* A, const BlockQ8_K* B, float* C,
+                                  int M, int K, int N, int blocksPerRow) {
+    const int row = blockIdx.y;
+    const int warp = threadIdx.y;
+    const int lane = threadIdx.x;
+    const int col = blockIdx.x * 8 + warp;
+
+    if (row >= M) return;
+    const bool colIn = (col < N);
+
+    float sum = 0.0f;
+
+    __shared__ float a_sh[256];
+    __shared__ float d_sh[8];
+
+    for (int blk = 0; blk < blocksPerRow; blk++) {
+        const int tid = warp * 32 + lane;
+        const float* aRow = A + row * K + blk * 256;
+        a_sh[tid] = aRow[tid];
+
+        if (colIn && lane == 0) {
+            d_sh[warp] = B[col * blocksPerRow + blk].d;
+        }
+
+        __syncthreads();
+
+        if (colIn) {
+            const BlockQ8_K* b = &B[col * blocksPerRow + blk];
+            const float d = d_sh[warp];
+
+            // Each lane handles 8 weights in the 256-wide block.
+            #pragma unroll
+            for (int i = 0; i < 8; i++) {
+                const int idx = lane + (i * 32); // 0..255
+                const float w = d * (float)((int)b->qs[idx]);
+                sum += a_sh[idx] * w;
+            }
+        }
+
+        __syncthreads();
+    }
+
+    // Warp reduction
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        sum += __shfl_down_sync(0xffffffff, sum, offset);
+    }
+    if (colIn && lane == 0) {
+        C[row * N + col] = sum;
+    }
+}
+
+int cuda_matmul_f32_q8k(float* A, const void* B, float* C, int M, int K, int N) {
+    int blocksPerRow = K / 256;
+    dim3 threads(32, 8);
+    dim3 blocks((N + 7) / 8, M);
+    
+    matmul_q8k_kernel<<<blocks, threads>>>(A, (const BlockQ8_K*)B, C, M, K, N, blocksPerRow);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+int cuda_matmul_f32_q8k_timed(float* A, const void* B, float* C, int M, int K, int N, float* ms) {
+    cudaEvent_t evStart;
+    cudaEvent_t evStop;
+    CHECK_CUDA(cudaEventCreate(&evStart));
+    CHECK_CUDA(cudaEventCreate(&evStop));
+
+    int blocksPerRow = K / 256;
+    dim3 threads(32, 8);
+    dim3 blocks((N + 7) / 8, M);
+
+    CHECK_CUDA(cudaEventRecord(evStart));
+    matmul_q8k_kernel<<<blocks, threads>>>(A, (const BlockQ8_K*)B, C, M, K, N, blocksPerRow);
+    CHECK_CUDA(cudaEventRecord(evStop));
+    CHECK_CUDA(cudaEventSynchronize(evStop));
+
+    float elapsed = 0.0f;
+    CHECK_CUDA(cudaEventElapsedTime(&elapsed, evStart, evStop));
+    if (ms != NULL) {
+        *ms = elapsed;
+    }
+
+    CHECK_CUDA(cudaEventDestroy(evStart));
+    CHECK_CUDA(cudaEventDestroy(evStop));
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// ============================================================
+// FP16 Input Variants - 2x memory bandwidth for activations
+// Input A is FP16, dequantized weights computed in FP32,
+// accumulation in FP32, output FP32.
+// ============================================================
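+//
+// As in the FP32 paths, K is assumed to be a multiple of 256 (one K-quant
+// super-block per 256 weights); A is a row-major [M, K] array of __half.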
+
+__global__ void matmul_q8k_kernel_f16in(const __half* A, const BlockQ8_K* B, float* C,
+                                         int M, int K, int N, int blocksPerRow) {
+    const int row = blockIdx.y;
+    const int warp = threadIdx.y;
+    const int lane = threadIdx.x;
+    const int col = blockIdx.x * 8 + warp;
+
+    if (row >= M) return;
+    const bool colIn = (col < N);
+
+    float sum = 0.0f;
+
+    __shared__ float a_sh[256];
+    __shared__ float d_sh[8];
+
+    for (int blk = 0; blk < blocksPerRow; blk++) {
+        const int tid = warp * 32 + lane;
+        const __half* aRow = A + row * K + blk * 256;
+        // Load FP16, convert to FP32 in shared memory
+        a_sh[tid] = __half2float(aRow[tid]);
+
+        if (colIn && lane == 0) {
+            d_sh[warp] = B[col * blocksPerRow + blk].d;
+        }
+
+        __syncthreads();
+
+        if (colIn) {
+            const BlockQ8_K* b = &B[col * blocksPerRow + blk];
+            const float d = d_sh[warp];
+
+            #pragma unroll
+            for (int i = 0; i < 8; i++) {
+                const int idx = lane + (i * 32);
+                const float w = d * (float)((int)b->qs[idx]);
+                sum += a_sh[idx] * w;
+            }
+        }
+
+        __syncthreads();
+    }
+
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        sum += __shfl_down_sync(0xffffffff, sum, offset);
+    }
+    if (colIn && lane == 0) {
+        C[row * N + col] = sum;
+    }
+}
+
+int cuda_matmul_f16_q8k(const void* A, const void* B, float* C, int M, int K, int N) {
+    int blocksPerRow = K / 256;
+    dim3 threads(32, 8);
+    dim3 blocks((N + 7) / 8, M);
+    
+    matmul_q8k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ8_K*)B, C, M, K, N, blocksPerRow);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// ============================================================
+// Fused Q4_K MatMul Kernel
+// Caches the A tile and per-group scale/min values in shared memory; the
+// quantized weights themselves are still read from global memory, so deeper
+// tiling would be needed for full performance.
+// ============================================================
+__global__ void matmul_q4k_kernel(float* A, const BlockQ4_K* B, float* C,
+                                  int M, int K, int N, int blocksPerRow) {
+    const int row = blockIdx.y;
+    const int warp = threadIdx.y;
+    const int lane = threadIdx.x;
+    const int col = blockIdx.x * 8 + warp;
+
+    if (row >= M) return;
+    const bool colIn = (col < N);
+
+    float sum = 0.0f;
+
+    __shared__ float a_sh[256];
+    __shared__ unsigned char sc_sh[8][8];
+    __shared__ unsigned char m_sh[8][8];
+    __shared__ float ds_sh[8][8];
+    __shared__ float dm_sh[8][8];
+    __shared__ float d_sh[8];
+    __shared__ float dmin_sh[8];
+
+    for (int blk = 0; blk < blocksPerRow; blk++) {
+        const int tid = warp * 32 + lane;
+        const float* aRow = A + row * K + blk * 256;
+        a_sh[tid] = aRow[tid];
+
+        if (colIn) {
+            const BlockQ4_K* b = &B[col * blocksPerRow + blk];
+
+            // Parallel unpack scale/min for groups 0..7 (one lane per group).
+            if (lane < 8) {
+                unsigned char sc;
+                unsigned char mn;
+                if (lane < 4) {
+                    sc = b->scales[lane] & 63;
+                    mn = b->scales[lane + 4] & 63;
+                } else {
+                    const int j = lane;
+                    sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
+                    mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
+                }
+                sc_sh[warp][lane] = sc;
+                m_sh[warp][lane] = mn;
+            }
+            if (lane == 0) {
+                d_sh[warp] = fp16_to_fp32(b->d);
+                dmin_sh[warp] = fp16_to_fp32(b->dmin);
+            }
+        }
+
+        __syncthreads();
+
+        if (colIn) {
+            const BlockQ4_K* b = &B[col * blocksPerRow + blk];
+
+            // Precompute per-group float multipliers once.
+            if (lane < 8) {
+                const float d = d_sh[warp];
+                const float dmin = dmin_sh[warp];
+                const unsigned char sc = sc_sh[warp][lane];
+                const unsigned char mn = m_sh[warp][lane];
+                ds_sh[warp][lane] = d * (float)sc;
+                dm_sh[warp][lane] = dmin * (float)mn;
+            }
+            __syncwarp();
+
+            const float ds0 = ds_sh[warp][0];
+            const float dm0 = dm_sh[warp][0];
+            const float ds1 = ds_sh[warp][1];
+            const float dm1 = dm_sh[warp][1];
+            const float ds2 = ds_sh[warp][2];
+            const float dm2 = dm_sh[warp][2];
+            const float ds3 = ds_sh[warp][3];
+            const float dm3 = dm_sh[warp][3];
+            const float ds4 = ds_sh[warp][4];
+            const float dm4 = dm_sh[warp][4];
+            const float ds5 = ds_sh[warp][5];
+            const float dm5 = dm_sh[warp][5];
+            const float ds6 = ds_sh[warp][6];
+            const float dm6 = dm_sh[warp][6];
+            const float ds7 = ds_sh[warp][7];
+            const float dm7 = dm_sh[warp][7];
+
+            // Each lane processes 4 bytes; each byte contains 2 nibbles => 8 values per lane.
+            // This halves qs loads and reduces bit ops.
+            #pragma unroll
+            for (int p = 0; p < 4; p++) {
+                const unsigned char qs = b->qs[p * 32 + lane];
+                const int q0 = qs & 0xF;
+                const int q1 = qs >> 4;
+
+                const int idx0 = p * 64 + lane;      // group = 2*p
+                const int idx1 = idx0 + 32;          // group = 2*p + 1
+
+                float dsA, dmA, dsB, dmB;
+                if (p == 0) {
+                    dsA = ds0; dmA = dm0; dsB = ds1; dmB = dm1;
+                } else if (p == 1) {
+                    dsA = ds2; dmA = dm2; dsB = ds3; dmB = dm3;
+                } else if (p == 2) {
+                    dsA = ds4; dmA = dm4; dsB = ds5; dmB = dm5;
+                } else {
+                    dsA = ds6; dmA = dm6; dsB = ds7; dmB = dm7;
+                }
+
+                const float w0 = (float)q0 * dsA - dmA;
+                const float w1 = (float)q1 * dsB - dmB;
+
+                sum += a_sh[idx0] * w0;
+                sum += a_sh[idx1] * w1;
+            }
+        }
+
+        __syncthreads();
+    }
+
+    // Warp reduction
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        sum += __shfl_down_sync(0xffffffff, sum, offset);
+    }
+    if (colIn && lane == 0) {
+        C[row * N + col] = sum;
+    }
+}
+
+int cuda_matmul_f32_q4k(float* A, const void* B, float* C, int M, int K, int N) {
+    int blocksPerRow = K / 256;
+    dim3 threads(32, 8);
+    dim3 blocks((N + 7) / 8, M);
+    
+    matmul_q4k_kernel<<<blocks, threads>>>(A, (const BlockQ4_K*)B, C, M, K, N, blocksPerRow);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+int cuda_matmul_f32_q4k_timed(float* A, const void* B, float* C, int M, int K, int N, float* ms) {
+    cudaEvent_t evStart;
+    cudaEvent_t evStop;
+    CHECK_CUDA(cudaEventCreate(&evStart));
+    CHECK_CUDA(cudaEventCreate(&evStop));
+
+    int blocksPerRow = K / 256;
+    dim3 threads(32, 8);
+    dim3 blocks((N + 7) / 8, M);
+
+    CHECK_CUDA(cudaEventRecord(evStart));
+    matmul_q4k_kernel<<<blocks, threads>>>(A, (const BlockQ4_K*)B, C, M, K, N, blocksPerRow);
+    CHECK_CUDA(cudaEventRecord(evStop));
+    CHECK_CUDA(cudaEventSynchronize(evStop));
+
+    float elapsed = 0.0f;
+    CHECK_CUDA(cudaEventElapsedTime(&elapsed, evStart, evStop));
+    if (ms != NULL) {
+        *ms = elapsed;
+    }
+
+    CHECK_CUDA(cudaEventDestroy(evStart));
+    CHECK_CUDA(cudaEventDestroy(evStop));
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// ============================================================
+// Fused Q2_K MatMul Kernel - Naive
+// ============================================================
+__global__ void matmul_q2k_kernel(float* A, const BlockQ2_K* B, float* C,
+                                  int M, int K, int N, int blocksPerRow) {
+    const int row = blockIdx.y;
+    const int warp = threadIdx.y;
+    const int lane = threadIdx.x;
+    const int col = blockIdx.x * 8 + warp;
+
+    if (row >= M) return;
+    const bool colIn = (col < N);
+
+    float sum = 0.0f;
+
+    __shared__ float a_sh[256];
+
+    __shared__ float d_sh[8];
+    __shared__ float dmin_sh[8];
+    __shared__ unsigned char scales_sh[8][16];
+    __shared__ unsigned char qs_sh[8][64];
+
+    for (int blk = 0; blk < blocksPerRow; blk++) {
+        // Cache A tile once per block (256 floats) to avoid redundant global loads.
+        // Each thread loads one element: tid in [0,255].
+        const int tid = warp * 32 + lane;
+        const float* aRow = A + row * K + blk * 256;
+        a_sh[tid] = aRow[tid];
+
+        if (colIn) {
+            const BlockQ2_K* b = &B[col * blocksPerRow + blk];
+
+            // Cooperative per-warp cache.
+            if (lane == 0) {
+                d_sh[warp] = fp16_to_fp32(b->d);
+                dmin_sh[warp] = fp16_to_fp32(b->dmin);
+            }
+            if (lane < 16) {
+                scales_sh[warp][lane] = b->scales[lane];
+            }
+            // Load 64 bytes qs with 32 lanes.
+            qs_sh[warp][lane] = b->qs[lane];
+            qs_sh[warp][lane + 32] = b->qs[lane + 32];
+        }
+
+        __syncthreads();
+
+        if (colIn) {
+            const float d = d_sh[warp];
+            const float dmin = dmin_sh[warp];
+
+            // Each lane handles 8 values.
+            #pragma unroll
+            for (int i = 0; i < 8; i++) {
+                const int idx = lane + (i * 32); // 0..255
+                const int is = idx >> 5;         // 0..7
+                const int iq = idx & 31;         // 0..31
+
+                const int qsIdx = (is >> 2) * 32 + iq;
+                const int shift = (is & 3) * 2;
+                const int val = (qs_sh[warp][qsIdx] >> shift) & 3;
+
+                const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
+                const unsigned char sc = scales_sh[warp][scIdx];
+                const float dl = d * (float)(sc & 0xF);
+                const float ml = dmin * (float)(sc >> 4);
+
+                const float w = dl * (float)val - ml;
+                sum += a_sh[idx] * w;
+            }
+        }
+
+        // Ensure all warps finished reading this block's a_sh before it is overwritten.
+        __syncthreads();
+    }
+
+    // Warp reduction
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        sum += __shfl_down_sync(0xffffffff, sum, offset);
+    }
+    if (colIn && lane == 0) {
+        C[row * N + col] = sum;
+    }
+}
+
+int cuda_matmul_f32_q2k(float* A, const void* B, float* C, int M, int K, int N) {
+    int blocksPerRow = K / 256;
+    dim3 threads(32, 8);
+    dim3 blocks((N + 7) / 8, M);
+    
+    matmul_q2k_kernel<<<blocks, threads>>>(A, (const BlockQ2_K*)B, C, M, K, N, blocksPerRow);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// ============================================================
+// Fused Q3_K MatMul Kernel
+// ============================================================
+__global__ void matmul_q3k_kernel(float* A, const BlockQ3_K* B, float* C,
+                                  int M, int K, int N, int blocksPerRow) {
+    const int row = blockIdx.y;
+    const int warp = threadIdx.y;
+    const int lane = threadIdx.x;
+    const int col = blockIdx.x * 8 + warp;
+
+    if (row >= M) return;
+    const bool colIn = (col < N);
+
+    float sum = 0.0f;
+
+    __shared__ float a_sh[256];
+
+    __shared__ float d_sh[8];
+    __shared__ unsigned char scales_sh[8][12];
+    __shared__ unsigned char qs_sh[8][64];
+    __shared__ unsigned char hmask_sh[8][32];
+
+    for (int blk = 0; blk < blocksPerRow; blk++) {
+        // Cache A tile once per block (256 floats). Each thread loads one element.
+        const int tid = warp * 32 + lane;
+        const float* aRow = A + row * K + blk * 256;
+        a_sh[tid] = aRow[tid];
+
+        if (colIn) {
+            const BlockQ3_K* b = &B[col * blocksPerRow + blk];
+
+            // Cache quant block bytes.
+            if (lane == 0) {
+                d_sh[warp] = fp16_to_fp32(b->d);
+            }
+            if (lane < 12) {
+                scales_sh[warp][lane] = b->scales[lane];
+            }
+            // qs: 64 bytes
+            qs_sh[warp][lane] = b->qs[lane];
+            qs_sh[warp][lane + 32] = b->qs[lane + 32];
+            // hmask: 32 bytes
+            hmask_sh[warp][lane] = b->hmask[lane];
+        }
+
+        __syncthreads();
+
+        if (colIn) {
+            const float d = d_sh[warp];
+
+            // Each lane handles 8 elements.
+            #pragma unroll
+            for (int i = 0; i < 8; i++) {
+                const int idx = lane + (i * 32); // 0..255
+                const int is = idx >> 5;         // 0..7
+                const int iq = idx & 31;         // 0..31
+
+                const int qsIdx = (is >> 2) * 32 + iq;
+                const int shift = (is & 3) * 2;
+                int qv = (qs_sh[warp][qsIdx] >> shift) & 0x3;
+
+                const unsigned char m = (unsigned char)(1 << ((is >> 2) * 4 + (is & 3)));
+                if ((hmask_sh[warp][iq] & m) == 0) {
+                    qv -= 4;
+                }
+
+                const int sIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4); // 0..15
+                unsigned char sc;
+                if (sIdx < 8) {
+                    sc = scales_sh[warp][sIdx] & 0xF;
+                } else {
+                    sc = scales_sh[warp][sIdx - 8] >> 4;
+                }
+                sc |= ((scales_sh[warp][8 + (sIdx & 3)] >> (2 * (sIdx >> 2))) & 0x3) << 4;
+                const float scale = (float)((int)((signed char)sc) - 32);
+
+                const float w = d * scale * (float)qv;
+                sum += a_sh[idx] * w;
+            }
+        }
+
+        __syncthreads();
+    }
+
+    // Warp reduction
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        sum += __shfl_down_sync(0xffffffff, sum, offset);
+    }
+    if (colIn && lane == 0) {
+        C[row * N + col] = sum;
+    }
+}
+
+int cuda_matmul_f32_q3k(float* A, const void* B, float* C, int M, int K, int N) {
+    int blocksPerRow = K / 256;
+    dim3 threads(32, 8);
+    dim3 blocks((N + 7) / 8, M);
+    matmul_q3k_kernel<<<blocks, threads>>>(A, (const BlockQ3_K*)B, C, M, K, N, blocksPerRow);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// ============================================================
+// Fused Q6_K MatMul Kernel
+// ============================================================
+__global__ void matmul_q6k_kernel(float* A, const BlockQ6_K* B, float* C,
+                                  int M, int K, int N, int blocksPerRow) {
+    const int row = blockIdx.y;
+    const int warp = threadIdx.y;
+    const int lane = threadIdx.x;
+    const int col = blockIdx.x * 8 + warp;
+
+    if (row >= M) return;
+    const bool colIn = (col < N);
+
+    float sum = 0.0f;
+
+    __shared__ float a_sh[256];
+
+    __shared__ float d_sh[8];
+    __shared__ signed char scales_sh[8][16];
+    __shared__ unsigned char ql_sh[8][128];
+    __shared__ unsigned char qh_sh[8][64];
+
+    for (int blk = 0; blk < blocksPerRow; blk++) {
+        // Cache A tile once per block.
+        const int tid = warp * 32 + lane;
+        const float* aRow = A + row * K + blk * 256;
+        a_sh[tid] = aRow[tid];
+
+        if (colIn) {
+            const BlockQ6_K* b = &B[col * blocksPerRow + blk];
+
+            // Cache quant block bytes.
+            if (lane == 0) {
+                d_sh[warp] = fp16_to_fp32(b->d);
+            }
+            if (lane < 16) {
+                scales_sh[warp][lane] = b->scales[lane];
+            }
+
+            // qh: 64 bytes
+            qh_sh[warp][lane] = b->qh[lane];
+            qh_sh[warp][lane + 32] = b->qh[lane + 32];
+
+            // ql: 128 bytes
+            ql_sh[warp][lane] = b->ql[lane];
+            ql_sh[warp][lane + 32] = b->ql[lane + 32];
+            ql_sh[warp][lane + 64] = b->ql[lane + 64];
+            ql_sh[warp][lane + 96] = b->ql[lane + 96];
+        }
+
+        __syncthreads();
+
+        if (colIn) {
+            const float d = d_sh[warp];
+
+            // Each lane handles 8 elements.
+            #pragma unroll
+            for (int i = 0; i < 8; i++) {
+                const int idx = lane + (i * 32); // 0..255
+                const int is = idx >> 5;         // 0..7
+                const int iq = idx & 31;         // 0..31
+
+                const int qlIdx = (is >> 2) * 64 + (is & 1) * 32 + iq;
+                const int qhIdx = (is >> 2) * 32 + iq;
+                const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
+
+                const unsigned char ql = ql_sh[warp][qlIdx];
+                const unsigned char qh = qh_sh[warp][qhIdx];
+
+                const int shift_ql = ((is & 3) < 2) ? 0 : 4;
+                const int shift_qh = (is & 3) * 2;
+
+                int q = ((ql >> shift_ql) & 0xF) | (((qh >> shift_qh) & 3) << 4);
+                q -= 32;
+
+                const float w = d * (float)scales_sh[warp][scIdx] * (float)q;
+                sum += a_sh[idx] * w;
+            }
+        }
+
+        __syncthreads();
+    }
+
+    // Warp reduction
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        sum += __shfl_down_sync(0xffffffff, sum, offset);
+    }
+    if (colIn && lane == 0) {
+        C[row * N + col] = sum;
+    }
+}
+
+int cuda_matmul_f32_q6k(float* A, const void* B, float* C, int M, int K, int N) {
+    int blocksPerRow = K / 256;
+    dim3 threads(32, 8);
+    dim3 blocks((N + 7) / 8, M);
+    matmul_q6k_kernel<<<blocks, threads>>>(A, (const BlockQ6_K*)B, C, M, K, N, blocksPerRow);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// ============================================================
+// FP16 Input Variants for Q4K, Q5K, Q2K, Q3K, Q6K
+// Same logic as FP32 versions but load A as FP16
+// ============================================================
+
+__global__ void matmul_q4k_kernel_f16in(const __half* A, const BlockQ4_K* B, float* C,
+                                         int M, int K, int N, int blocksPerRow) {
+    const int row = blockIdx.y;
+    const int warp = threadIdx.y;
+    const int lane = threadIdx.x;
+    const int col = blockIdx.x * 8 + warp;
+
+    if (row >= M) return;
+    const bool colIn = (col < N);
+
+    float sum = 0.0f;
+
+    __shared__ float a_sh[256];
+    __shared__ unsigned char sc_sh[8][8];
+    __shared__ unsigned char m_sh[8][8];
+    __shared__ float ds_sh[8][8];
+    __shared__ float dm_sh[8][8];
+    __shared__ float d_sh[8];
+    __shared__ float dmin_sh[8];
+
+    for (int blk = 0; blk < blocksPerRow; blk++) {
+        const int tid = warp * 32 + lane;
+        const __half* aRow = A + row * K + blk * 256;
+        a_sh[tid] = __half2float(aRow[tid]);
+
+        if (colIn) {
+            const BlockQ4_K* b = &B[col * blocksPerRow + blk];
+            if (lane < 8) {
+                unsigned char sc, mn;
+                if (lane < 4) {
+                    sc = b->scales[lane] & 63;
+                    mn = b->scales[lane + 4] & 63;
+                } else {
+                    const int j = lane;
+                    sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
+                    mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
+                }
+                sc_sh[warp][lane] = sc;
+                m_sh[warp][lane] = mn;
+            }
+            if (lane == 0) {
+                d_sh[warp] = fp16_to_fp32(b->d);
+                dmin_sh[warp] = fp16_to_fp32(b->dmin);
+            }
+        }
+
+        __syncthreads();
+
+        if (colIn) {
+            const BlockQ4_K* b = &B[col * blocksPerRow + blk];
+            if (lane < 8) {
+                ds_sh[warp][lane] = d_sh[warp] * (float)sc_sh[warp][lane];
+                dm_sh[warp][lane] = dmin_sh[warp] * (float)m_sh[warp][lane];
+            }
+            __syncwarp();
+
+            #pragma unroll
+            for (int p = 0; p < 4; p++) {
+                const unsigned char qs = b->qs[p * 32 + lane];
+                const int q0 = qs & 0xF;
+                const int q1 = qs >> 4;
+                const int idx0 = p * 64 + lane;
+                const int idx1 = idx0 + 32;
+                const int g0 = 2 * p;
+                const int g1 = g0 + 1;
+                sum += a_sh[idx0] * ((float)q0 * ds_sh[warp][g0] - dm_sh[warp][g0]);
+                sum += a_sh[idx1] * ((float)q1 * ds_sh[warp][g1] - dm_sh[warp][g1]);
+            }
+        }
+        __syncthreads();
+    }
+
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        sum += __shfl_down_sync(0xffffffff, sum, offset);
+    }
+    if (colIn && lane == 0) {
+        C[row * N + col] = sum;
+    }
+}
+
+int cuda_matmul_f16_q4k(const void* A, const void* B, float* C, int M, int K, int N) {
+    int blocksPerRow = K / 256;
+    dim3 threads(32, 8);
+    dim3 blocks((N + 7) / 8, M);
+    matmul_q4k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ4_K*)B, C, M, K, N, blocksPerRow);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+__global__ void matmul_q5k_kernel_f16in(const __half* A, const BlockQ5_K* B, float* C,
+                                         int M, int K, int N, int blocksPerRow) {
+    const int row = blockIdx.y;
+    const int warp = threadIdx.y;
+    const int lane = threadIdx.x;
+    const int col = blockIdx.x * 8 + warp;
+
+    if (row >= M) return;
+    const bool colIn = (col < N);
+
+    float sum = 0.0f;
+
+    __shared__ float a_sh[256];
+    __shared__ unsigned char sc_sh[8][8];
+    __shared__ unsigned char m_sh[8][8];
+    __shared__ float ds_sh[8][8];
+    __shared__ float dm_sh[8][8];
+    __shared__ float d_sh[8];
+    __shared__ float dmin_sh[8];
+
+    for (int blk = 0; blk < blocksPerRow; blk++) {
+        const int tid = warp * 32 + lane;
+        const __half* aRow = A + row * K + blk * 256;
+        a_sh[tid] = __half2float(aRow[tid]);
+
+        if (colIn) {
+            const BlockQ5_K* b = &B[col * blocksPerRow + blk];
+            if (lane < 8) {
+                unsigned char sc, mn;
+                if (lane < 4) {
+                    sc = b->scales[lane] & 63;
+                    mn = b->scales[lane + 4] & 63;
+                } else {
+                    const int j = lane;
+                    sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
+                    mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
+                }
+                sc_sh[warp][lane] = sc;
+                m_sh[warp][lane] = mn;
+            }
+            if (lane == 0) {
+                d_sh[warp] = fp16_to_fp32(b->d);
+                dmin_sh[warp] = fp16_to_fp32(b->dmin);
+            }
+        }
+
+        __syncthreads();
+
+        if (colIn) {
+            const BlockQ5_K* b = &B[col * blocksPerRow + blk];
+            if (lane < 8) {
+                ds_sh[warp][lane] = d_sh[warp] * (float)sc_sh[warp][lane];
+                dm_sh[warp][lane] = dmin_sh[warp] * (float)m_sh[warp][lane];
+            }
+            __syncwarp();
+
+            const unsigned char hb = b->qh[lane];
+            #pragma unroll
+            for (int p = 0; p < 4; p++) {
+                const unsigned char qs = b->qs[p * 32 + lane];
+                int q0 = qs & 0xF;
+                int q1 = qs >> 4;
+                q0 += ((hb >> (2 * p)) & 1) << 4;
+                q1 += ((hb >> (2 * p + 1)) & 1) << 4;
+                const int idx0 = p * 64 + lane;
+                const int idx1 = idx0 + 32;
+                const int g0 = 2 * p;
+                const int g1 = g0 + 1;
+                sum += a_sh[idx0] * ((float)q0 * ds_sh[warp][g0] - dm_sh[warp][g0]);
+                sum += a_sh[idx1] * ((float)q1 * ds_sh[warp][g1] - dm_sh[warp][g1]);
+            }
+        }
+        __syncthreads();
+    }
+
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        sum += __shfl_down_sync(0xffffffff, sum, offset);
+    }
+    if (colIn && lane == 0) {
+        C[row * N + col] = sum;
+    }
+}
+
+int cuda_matmul_f16_q5k(const void* A, const void* B, float* C, int M, int K, int N) {
+    int blocksPerRow = K / 256;
+    dim3 threads(32, 8);
+    dim3 blocks((N + 7) / 8, M);
+    matmul_q5k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ5_K*)B, C, M, K, N, blocksPerRow);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
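+
+// Recap of the Q5_K scale/min unpacking above (illustrative): each 256-value super-block
+// carries eight 6-bit scales and eight 6-bit mins packed into scales[12]. For group j < 4
+// the 6 bits live directly in scales[j] / scales[j+4]; for j >= 4 the low 4 bits come from
+// scales[j+4] and the top 2 bits are borrowed from the high bits of scales[j-4] / scales[j].
+// A value q (low nibble from qs plus one high bit from qh) then dequantizes as
+//     x = d * sc[j] * q - dmin * m[j]
+// which is exactly the ds_sh/dm_sh form accumulated in the inner loop.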
+
+__global__ void matmul_q2k_kernel_f16in(const __half* A, const BlockQ2_K* B, float* C,
+                                         int M, int K, int N, int blocksPerRow) {
+    const int row = blockIdx.y;
+    const int warp = threadIdx.y;
+    const int lane = threadIdx.x;
+    const int col = blockIdx.x * 8 + warp;
+
+    if (row >= M) return;
+    const bool colIn = (col < N);
+
+    float sum = 0.0f;
+
+    __shared__ float a_sh[256];
+    __shared__ float d_sh[8];
+    __shared__ float dmin_sh[8];
+    __shared__ unsigned char scales_sh[8][16];
+    __shared__ unsigned char qs_sh[8][64];
+
+    for (int blk = 0; blk < blocksPerRow; blk++) {
+        const int tid = warp * 32 + lane;
+        const __half* aRow = A + row * K + blk * 256;
+        a_sh[tid] = __half2float(aRow[tid]);
+
+        if (colIn) {
+            const BlockQ2_K* b = &B[col * blocksPerRow + blk];
+            if (lane == 0) {
+                d_sh[warp] = fp16_to_fp32(b->d);
+                dmin_sh[warp] = fp16_to_fp32(b->dmin);
+            }
+            if (lane < 16) {
+                scales_sh[warp][lane] = b->scales[lane];
+            }
+            qs_sh[warp][lane] = b->qs[lane];
+            qs_sh[warp][lane + 32] = b->qs[lane + 32];
+        }
+
+        __syncthreads();
+
+        if (colIn) {
+            const float d = d_sh[warp];
+            const float dmin = dmin_sh[warp];
+
+            #pragma unroll
+            for (int i = 0; i < 8; i++) {
+                const int idx = lane + (i * 32);
+                const int is = idx >> 5;
+                const int iq = idx & 31;
+                const int qsIdx = (is >> 2) * 32 + iq;
+                const int shift = (is & 3) * 2;
+                const int val = (qs_sh[warp][qsIdx] >> shift) & 3;
+                const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
+                const unsigned char sc = scales_sh[warp][scIdx];
+                const float dl = d * (float)(sc & 0xF);
+                const float ml = dmin * (float)(sc >> 4);
+                sum += a_sh[idx] * (dl * (float)val - ml);
+            }
+        }
+        __syncthreads();
+    }
+
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        sum += __shfl_down_sync(0xffffffff, sum, offset);
+    }
+    if (colIn && lane == 0) {
+        C[row * N + col] = sum;
+    }
+}
+
+int cuda_matmul_f16_q2k(const void* A, const void* B, float* C, int M, int K, int N) {
+    int blocksPerRow = K / 256;
+    dim3 threads(32, 8);
+    dim3 blocks((N + 7) / 8, M);
+    matmul_q2k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ2_K*)B, C, M, K, N, blocksPerRow);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
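+
+// Worked example for the Q2_K index arithmetic above (illustrative): for value index
+// idx = 200 within a super-block, is = idx>>5 = 6 and iq = idx&31 = 8, so the 2-bit quant
+// lives in qs byte (is>>2)*32 + iq = 40 at shift (is&3)*2 = 4, and the packed scale/min
+// byte is scales[(is>>2)*8 + (is&3)*2 + (iq>>4)] = scales[12]. Its low nibble scales the
+// value (d * (sc & 0xF)) and its high nibble feeds the min term (dmin * (sc >> 4)).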
+
+__global__ void matmul_q3k_kernel_f16in(const __half* A, const BlockQ3_K* B, float* C,
+                                         int M, int K, int N, int blocksPerRow) {
+    const int row = blockIdx.y;
+    const int warp = threadIdx.y;
+    const int lane = threadIdx.x;
+    const int col = blockIdx.x * 8 + warp;
+
+    if (row >= M) return;
+    const bool colIn = (col < N);
+
+    float sum = 0.0f;
+
+    __shared__ float a_sh[256];
+    __shared__ float d_sh[8];
+    __shared__ unsigned char scales_sh[8][12];
+    __shared__ unsigned char qs_sh[8][64];
+    __shared__ unsigned char hmask_sh[8][32];
+
+    for (int blk = 0; blk < blocksPerRow; blk++) {
+        const int tid = warp * 32 + lane;
+        const __half* aRow = A + row * K + blk * 256;
+        a_sh[tid] = __half2float(aRow[tid]);
+
+        if (colIn) {
+            const BlockQ3_K* b = &B[col * blocksPerRow + blk];
+            if (lane == 0) {
+                d_sh[warp] = fp16_to_fp32(b->d);
+            }
+            if (lane < 12) {
+                scales_sh[warp][lane] = b->scales[lane];
+            }
+            qs_sh[warp][lane] = b->qs[lane];
+            qs_sh[warp][lane + 32] = b->qs[lane + 32];
+            hmask_sh[warp][lane] = b->hmask[lane];
+        }
+
+        __syncthreads();
+
+        if (colIn) {
+            const float d = d_sh[warp];
+
+            #pragma unroll
+            for (int i = 0; i < 8; i++) {
+                const int idx = lane + (i * 32);
+                const int is = idx >> 5;
+                const int iq = idx & 31;
+                const int qsIdx = (is >> 2) * 32 + iq;
+                const int shift = (is & 3) * 2;
+                int qv = (qs_sh[warp][qsIdx] >> shift) & 0x3;
+                const unsigned char m = (unsigned char)(1 << ((is >> 2) * 4 + (is & 3)));
+                if ((hmask_sh[warp][iq] & m) == 0) {
+                    qv -= 4;
+                }
+                const int sIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
+                unsigned char sc;
+                if (sIdx < 8) {
+                    sc = scales_sh[warp][sIdx] & 0xF;
+                } else {
+                    sc = scales_sh[warp][sIdx - 8] >> 4;
+                }
+                sc |= ((scales_sh[warp][8 + (sIdx & 3)] >> (2 * (sIdx >> 2))) & 0x3) << 4;
+                const float scale = (float)((int)((signed char)sc) - 32);
+                sum += a_sh[idx] * (d * scale * (float)qv);
+            }
+        }
+        __syncthreads();
+    }
+
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        sum += __shfl_down_sync(0xffffffff, sum, offset);
+    }
+    if (colIn && lane == 0) {
+        C[row * N + col] = sum;
+    }
+}
+
+int cuda_matmul_f16_q3k(const void* A, const void* B, float* C, int M, int K, int N) {
+    int blocksPerRow = K / 256;
+    dim3 threads(32, 8);
+    dim3 blocks((N + 7) / 8, M);
+    matmul_q3k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ3_K*)B, C, M, K, N, blocksPerRow);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
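+
+// Recap of the Q3_K scale reconstruction above (illustrative): the 16 sub-block scales are
+// 6-bit signed values packed into scales[12]. Sub-blocks 0..7 take their low 4 bits from
+// the low nibble of scales[sIdx], sub-blocks 8..15 from the high nibble of scales[sIdx-8],
+// and every sub-block takes its top 2 bits from scales[8 + (sIdx & 3)]. The reassembled
+// scale is centered with "- 32", and the 2-bit quant itself is shifted down by 4 whenever
+// the corresponding hmask bit is clear, giving the signed range expected by Q3_K.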
+
+__global__ void matmul_q6k_kernel_f16in(const __half* A, const BlockQ6_K* B, float* C,
+                                         int M, int K, int N, int blocksPerRow) {
+    const int row = blockIdx.y;
+    const int warp = threadIdx.y;
+    const int lane = threadIdx.x;
+    const int col = blockIdx.x * 8 + warp;
+
+    if (row >= M) return;
+    const bool colIn = (col < N);
+
+    float sum = 0.0f;
+
+    __shared__ float a_sh[256];
+    __shared__ float d_sh[8];
+    __shared__ signed char scales_sh[8][16];
+    __shared__ unsigned char ql_sh[8][128];
+    __shared__ unsigned char qh_sh[8][64];
+
+    for (int blk = 0; blk < blocksPerRow; blk++) {
+        const int tid = warp * 32 + lane;
+        const __half* aRow = A + row * K + blk * 256;
+        a_sh[tid] = __half2float(aRow[tid]);
+
+        if (colIn) {
+            const BlockQ6_K* b = &B[col * blocksPerRow + blk];
+            if (lane == 0) {
+                d_sh[warp] = fp16_to_fp32(b->d);
+            }
+            if (lane < 16) {
+                scales_sh[warp][lane] = b->scales[lane];
+            }
+            qh_sh[warp][lane] = b->qh[lane];
+            qh_sh[warp][lane + 32] = b->qh[lane + 32];
+            ql_sh[warp][lane] = b->ql[lane];
+            ql_sh[warp][lane + 32] = b->ql[lane + 32];
+            ql_sh[warp][lane + 64] = b->ql[lane + 64];
+            ql_sh[warp][lane + 96] = b->ql[lane + 96];
+        }
+
+        __syncthreads();
+
+        if (colIn) {
+            const float d = d_sh[warp];
+
+            #pragma unroll
+            for (int i = 0; i < 8; i++) {
+                const int idx = lane + (i * 32);
+                const int is = idx >> 5;
+                const int iq = idx & 31;
+                const int qlIdx = (is >> 2) * 64 + (is & 1) * 32 + iq;
+                const int qhIdx = (is >> 2) * 32 + iq;
+                const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
+                const unsigned char ql = ql_sh[warp][qlIdx];
+                const unsigned char qh = qh_sh[warp][qhIdx];
+                const int shift_ql = ((is & 3) < 2) ? 0 : 4;
+                const int shift_qh = (is & 3) * 2;
+                int q = ((ql >> shift_ql) & 0xF) | (((qh >> shift_qh) & 3) << 4);
+                q -= 32;
+                sum += a_sh[idx] * (d * (float)scales_sh[warp][scIdx] * (float)q);
+            }
+        }
+        __syncthreads();
+    }
+
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        sum += __shfl_down_sync(0xffffffff, sum, offset);
+    }
+    if (colIn && lane == 0) {
+        C[row * N + col] = sum;
+    }
+}
+
+int cuda_matmul_f16_q6k(const void* A, const void* B, float* C, int M, int K, int N) {
+    int blocksPerRow = K / 256;
+    dim3 threads(32, 8);
+    dim3 blocks((N + 7) / 8, M);
+    matmul_q6k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ6_K*)B, C, M, K, N, blocksPerRow);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
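+
+// Recap of the Q6_K decode above (illustrative): each 6-bit value is assembled from a 4-bit
+// low part in ql (selected by shift_ql) and a 2-bit high part in qh (selected by shift_qh),
+// then re-centered with "- 32". With 16 signed 8-bit scales per super-block, the
+// contribution of one value is a * d * scales[scIdx] * q, accumulated into sum.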

+ 317 - 0
pkg/backend/cuda/cuda_memory.cu

@@ -0,0 +1,317 @@
+#include "cuda_common.cuh"
+#include <cstdlib>
+#include <cstring>
+#include <mutex>
+#include <unordered_map>
+#include <vector>
+
+// --- Memory ---
+
+namespace {
+
+struct free_block {
+    void * ptr;
+    size_t size;
+};
+
+struct device_pool {
+    std::mutex mu;
+    std::unordered_map<void *, size_t> alloc_sizes;
+    std::vector<free_block> free_list;
+    size_t cached_bytes = 0;
+};
+
+// Current CUDA device cached per host thread.
+// This is updated by cuda_set_device and used by cuda_malloc/cuda_free.
+static thread_local int tls_device = 0;
+
+// Keep a small per-device cache of freed allocations to avoid cudaMalloc/cudaFree churn
+// and to keep VRAM usage stable after first-touch allocations.
+static device_pool g_pools[16];
+static constexpr size_t MAX_FREE_BLOCKS_PER_DEVICE = 1024;
+static size_t g_pool_max_cached_bytes = 512ULL << 20;   // 512MB
+static size_t g_pool_max_block_bytes = 64ULL << 20;     // 64MB
+static bool g_pool_enabled = true;
+static std::once_flag g_pool_config_once;
+
+static size_t parse_env_bytes(const char * env, size_t def_val) {
+    if (env == nullptr || env[0] == '\0') {
+        return def_val;
+    }
+    char * end = nullptr;
+    unsigned long long val = std::strtoull(env, &end, 10);
+    if (end != nullptr && *end != '\0') {
+        switch (*end) {
+        case 'k':
+        case 'K':
+            val *= 1024ULL;
+            break;
+        case 'm':
+        case 'M':
+            val *= 1024ULL * 1024ULL;
+            break;
+        case 'g':
+        case 'G':
+            val *= 1024ULL * 1024ULL * 1024ULL;
+            break;
+        default:
+            break;
+        }
+    }
+    return static_cast<size_t>(val);
+}
+
+static bool env_true(const char * env) {
+    if (env == nullptr) {
+        return false;
+    }
+    if (std::strcmp(env, "1") == 0 || std::strcmp(env, "true") == 0 || std::strcmp(env, "TRUE") == 0) {
+        return true;
+    }
+    return false;
+}
+
+static void init_pool_config() {
+    std::call_once(g_pool_config_once, []() {
+        const char * disable = std::getenv("MAKARNA_CUDA_POOL_DISABLE");
+        if (env_true(disable)) {
+            g_pool_enabled = false;
+            g_pool_max_cached_bytes = 0;
+            g_pool_max_block_bytes = 0;
+            return;
+        }
+        const char * max_bytes = std::getenv("MAKARNA_CUDA_POOL_MAX_BYTES");
+        const char * max_block = std::getenv("MAKARNA_CUDA_POOL_MAX_BLOCK_BYTES");
+        g_pool_max_cached_bytes = parse_env_bytes(max_bytes, g_pool_max_cached_bytes);
+        g_pool_max_block_bytes = parse_env_bytes(max_block, g_pool_max_block_bytes);
+    });
+}
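+
+// Minimal configuration sketch (illustrative): the cache is on by default and can be tuned
+// or disabled through the environment variables read above, with optional k/m/g suffixes,
+// e.g.
+//
+//     MAKARNA_CUDA_POOL_DISABLE=1            # bypass the cache, always cudaMalloc/cudaFree
+//     MAKARNA_CUDA_POOL_MAX_BYTES=1g         # cache at most ~1 GiB of freed blocks per device
+//     MAKARNA_CUDA_POOL_MAX_BLOCK_BYTES=128m # never cache blocks larger than 128 MiB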
+
+static device_pool & pool_for(int device) {
+    if (device < 0) device = 0;
+    if (device >= 16) device = device % 16;
+    return g_pools[device];
+}
+
+static void * pool_alloc(int device, size_t size) {
+    init_pool_config();
+    device_pool & p = pool_for(device);
+    std::lock_guard<std::mutex> lock(p.mu);
+
+    // Best-fit search: pick the smallest block that satisfies the request.
+    size_t best_i = (size_t) -1;
+    size_t best_size = (size_t) -1;
+    for (size_t i = 0; i < p.free_list.size(); ++i) {
+        const free_block & b = p.free_list[i];
+        if (b.size >= size && b.size < best_size) {
+            best_i = i;
+            best_size = b.size;
+        }
+    }
+    if (best_i != (size_t) -1) {
+        void * ptr = p.free_list[best_i].ptr;
+        size_t bsize = p.free_list[best_i].size;
+        // erase by swap-with-back
+        p.free_list[best_i] = p.free_list.back();
+        p.free_list.pop_back();
+        if (p.cached_bytes >= bsize) {
+            p.cached_bytes -= bsize;
+        } else {
+            p.cached_bytes = 0;
+        }
+        return ptr;
+    }
+
+    return nullptr;
+}
+
+static void pool_record_alloc(int device, void * ptr, size_t size) {
+    if (ptr == nullptr) return;
+    device_pool & p = pool_for(device);
+    std::lock_guard<std::mutex> lock(p.mu);
+    p.alloc_sizes[ptr] = size;
+}
+
+static size_t pool_lookup_size(int device, void * ptr) {
+    device_pool & p = pool_for(device);
+    std::lock_guard<std::mutex> lock(p.mu);
+    auto it = p.alloc_sizes.find(ptr);
+    if (it == p.alloc_sizes.end()) {
+        return 0;
+    }
+    return it->second;
+}
+
+static int pool_find_device(void * ptr, size_t * out_size) {
+    if (out_size) *out_size = 0;
+    if (ptr == nullptr) return -1;
+    for (int d = 0; d < 16; ++d) {
+        device_pool & p = g_pools[d];
+        std::lock_guard<std::mutex> lock(p.mu);
+        auto it = p.alloc_sizes.find(ptr);
+        if (it != p.alloc_sizes.end()) {
+            if (out_size) *out_size = it->second;
+            return d;
+        }
+    }
+    return -1;
+}
+
+static void pool_free(int device, void * ptr) {
+    init_pool_config();
+    if (ptr == nullptr) return;
+    size_t size = pool_lookup_size(device, ptr);
+    int actual_device = device;
+    if (size == 0) {
+        int found = pool_find_device(ptr, &size);
+        if (found >= 0) {
+            actual_device = found;
+        }
+    }
+
+    device_pool & p = pool_for(actual_device);
+    std::lock_guard<std::mutex> lock(p.mu);
+    if (!g_pool_enabled || g_pool_max_cached_bytes == 0 || g_pool_max_block_bytes == 0 || size > g_pool_max_block_bytes) {
+        cudaSetDevice(actual_device);
+        cudaFree(ptr);
+        p.alloc_sizes.erase(ptr);
+        return;
+    }
+    if (p.free_list.size() >= MAX_FREE_BLOCKS_PER_DEVICE || p.cached_bytes+size > g_pool_max_cached_bytes) {
+        // Pool full: actually free.
+        cudaSetDevice(actual_device);
+        cudaFree(ptr);
+        p.alloc_sizes.erase(ptr);
+        return;
+    }
+    p.free_list.push_back(free_block{ptr, size});
+    p.cached_bytes += size;
+}
+
+} // namespace
+
+int cuda_set_device(int id) {
+    // cudaSetDevice is expensive when called repeatedly.
+    // Cache per host thread since CUDA device context is thread-affine.
+    if (tls_device == id) {
+        return 0;
+    }
+    CHECK_CUDA(cudaSetDevice(id));
+    tls_device = id;
+    return 0;
+}
+
+void* cuda_malloc(size_t size) {
+    init_pool_config();
+    const int device = tls_device;
+    void * ptr = pool_alloc(device, size);
+    if (ptr != nullptr) {
+        return ptr;
+    }
+
+    ptr = NULL;
+    if (cudaMalloc(&ptr, size) != cudaSuccess) {
+        return NULL;
+    }
+    pool_record_alloc(device, ptr, size);
+    return ptr;
+}
+
+void cuda_free(void* ptr) {
+    const int device = tls_device;
+    pool_free(device, ptr);
+}
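+
+// Minimal usage sketch of the pooled allocator (illustrative):
+//
+//     cuda_set_device(0);
+//     void* buf = cuda_malloc(1 << 20);   // first call hits cudaMalloc and records the size
+//     cuda_free(buf);                     // block is parked on device 0's free list
+//     void* again = cuda_malloc(1 << 20); // best-fit reuse, no cudaMalloc round-trip
+//     cuda_free(again);
+//
+// Blocks larger than g_pool_max_block_bytes, or anything past the per-device byte/entry
+// caps, are released with cudaFree immediately instead of being cached.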
+
+int cuda_synchronize() {
+    CHECK_CUDA(cudaDeviceSynchronize());
+    return 0;
+}
+
+int cuda_memcpy_h2d(void* dst, void* src, size_t size) {
+    CHECK_CUDA(cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice));
+    return 0;
+}
+
+int cuda_memcpy_d2h(void* dst, void* src, size_t size) {
+    CHECK_CUDA(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost));
+    return 0;
+}
+
+int cuda_memcpy_d2d(void* dst, void* src, size_t size) {
+    CHECK_CUDA(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice));
+    return 0;
+}
+
+int cuda_mem_info(size_t* free_bytes, size_t* total_bytes) {
+    // cudaMemGetInfo can return cudaErrorOperatingSystem in some restricted
+    // environments even though allocations/kernels work. Fall back to a probing
+    // allocation so higher-level placement logic can still function.
+    cudaError_t err = cudaMemGetInfo(free_bytes, total_bytes);
+    if (err == cudaSuccess) {
+        return 0;
+    }
+    if (err == cudaErrorOperatingSystem) {
+        // Some sandboxes block driver queries (MemGetInfo/GetDeviceProperties)
+        // but still allow allocations. Approximate "free" with a probing alloc.
+        (void)cudaGetLastError();
+
+        size_t max_ok = 0;
+        size_t probe = 256ULL << 20; // 256MB
+        const size_t max_probe = 64ULL << 30; // 64GB cap
+
+        void* p = nullptr;
+        while (probe <= max_probe) {
+            cudaError_t e = cudaMalloc(&p, probe);
+            if (e == cudaSuccess) {
+                (void)cudaFree(p);
+                p = nullptr;
+                max_ok = probe;
+                probe <<= 1;
+                continue;
+            }
+            (void)cudaGetLastError();
+            break;
+        }
+
+        size_t lo = max_ok;
+        size_t hi = probe;
+        // Binary search to 64MB granularity.
+        const size_t gran = 64ULL << 20;
+        while (hi > lo + gran) {
+            size_t mid = lo + (hi - lo) / 2;
+            mid = (mid / (1ULL << 20)) * (1ULL << 20); // align to 1MB
+            if (mid <= lo) {
+                break;
+            }
+            cudaError_t e = cudaMalloc(&p, mid);
+            if (e == cudaSuccess) {
+                (void)cudaFree(p);
+                p = nullptr;
+                lo = mid;
+            } else {
+                (void)cudaGetLastError();
+                hi = mid;
+            }
+        }
+
+        if (free_bytes) {
+            *free_bytes = lo;
+        }
+        if (total_bytes) {
+            *total_bytes = lo;
+        }
+        return 0;
+    }
+
+    fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, cudaGetErrorString(err));
+    return 1;
+}
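+
+// Worked example of the probing fallback above (illustrative): with roughly 10 GiB actually
+// available, the doubling phase succeeds at 256 MiB, 512 MiB, ... up to 8 GiB and fails at
+// 16 GiB, so the binary search then narrows [8 GiB, 16 GiB) in 1 MiB-aligned steps until the
+// bracket is within 64 MiB, and that lower bound is reported as both "free" and "total".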
+
+int cuda_device_count(int* count) {
+    int c = 0;
+    CHECK_CUDA(cudaGetDeviceCount(&c));
+    if (count) {
+        *count = c;
+    }
+    return 0;
+}

+ 2165 - 0
pkg/backend/cuda/cuda_nn.cu

@@ -0,0 +1,2165 @@
+#include "cuda_common.cuh"
+#include <cuda_fp16.h>
+#include <stdint.h>
+#include <string.h>
+
+namespace {
+constexpr int kPagedAttentionSplitSize = 1024;
+constexpr int kPagedAttentionSplitQHThreshold = 4096; // queryCount*numHeads threshold
+} // namespace
+
+// ============================================================
+// Fused RoPE helpers (constant step table + fast complex pow)
+// ============================================================
+
+// Stores cos/sin of per-dimension "step" angle (invFreq) for RoPE.
+// Only indices [0, headDim/2) are used. Max supported headDim is 256.
+__device__ __constant__ float rope_cos_step_const[128];
+__device__ __constant__ float rope_sin_step_const[128];
+
+static int g_rope_step_inited[32] = {0};
+static int g_rope_step_head_dim[32] = {0};
+static uint32_t g_rope_step_theta_bits[32] = {0};
+
+static int ensure_rope_step_table(int headDim, float theta) {
+    if (headDim <= 0 || headDim > 256 || (headDim & 1) != 0) {
+        return 1;
+    }
+
+    int dev = 0;
+    CHECK_CUDA(cudaGetDevice(&dev));
+    if (dev < 0 || dev >= 32) {
+        dev = 0;
+    }
+
+    uint32_t thetaBits = 0;
+    memcpy(&thetaBits, &theta, sizeof(thetaBits));
+    if (g_rope_step_inited[dev] && g_rope_step_head_dim[dev] == headDim && g_rope_step_theta_bits[dev] == thetaBits) {
+        return 0;
+    }
+
+    float cosStep[128];
+    float sinStep[128];
+    const int halfDim = headDim / 2;
+    for (int j = 0; j < 128; j++) {
+        if (j < halfDim) {
+            // invFreq = theta^(-2j/headDim)
+            const double exp = -2.0 * (double)j / (double)headDim;
+            const double invFreq = pow((double)theta, exp);
+            cosStep[j] = (float)cos(invFreq);
+            sinStep[j] = (float)sin(invFreq);
+        } else {
+            cosStep[j] = 1.0f;
+            sinStep[j] = 0.0f;
+        }
+    }
+
+    CHECK_CUDA(cudaMemcpyToSymbol(rope_cos_step_const, cosStep, sizeof(cosStep), 0, cudaMemcpyHostToDevice));
+    CHECK_CUDA(cudaMemcpyToSymbol(rope_sin_step_const, sinStep, sizeof(sinStep), 0, cudaMemcpyHostToDevice));
+
+    g_rope_step_inited[dev] = 1;
+    g_rope_step_head_dim[dev] = headDim;
+    g_rope_step_theta_bits[dev] = thetaBits;
+    return 0;
+}
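+
+// Illustrative recap of the step-table trick: for pair j the RoPE angle at position p is
+// p * invFreq_j with invFreq_j = theta^(-2j/headDim). Only cos/sin of the *step* invFreq_j
+// are stored in constant memory; the kernels recover the phase for an arbitrary position as
+// (cos(invFreq_j) + i*sin(invFreq_j))^p via complex_pow_int below (square-and-multiply,
+// O(log p) complex multiplies), and then walk it backwards one step per KV token.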
+
+__device__ __forceinline__ float2 complex_mul_f2(float2 a, float2 b) {
+    // (a.x + i a.y) * (b.x + i b.y)
+    return make_float2(
+        fmaf(a.x, b.x, -a.y * b.y),
+        fmaf(a.x, b.y, a.y * b.x)
+    );
+}
+
+__device__ __forceinline__ float2 complex_pow_int(float2 base, int exp) {
+    float2 result = make_float2(1.0f, 0.0f);
+    float2 b = base;
+    int e = exp;
+    while (e > 0) {
+        if (e & 1) {
+            result = complex_mul_f2(result, b);
+        }
+        b = complex_mul_f2(b, b);
+        e >>= 1;
+    }
+    return result;
+}
+
+__device__ __forceinline__ void rope_advance_neg(float& cosv, float& sinv, float cosStep, float sinStep) {
+    // Multiply by exp(-i*step): (cos + i sin) * (cosStep - i sinStep)
+    const float c = cosv;
+    const float s = sinv;
+    cosv = fmaf(c, cosStep, s * sinStep);
+    sinv = fmaf(s, cosStep, -c * sinStep);
+}
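+
+// The update above is just the angle-difference identities (illustrative):
+//     cos(x - s) = cos(x)*cos(s) + sin(x)*sin(s)
+//     sin(x - s) = sin(x)*cos(s) - cos(x)*sin(s)
+// so applying it repeatedly decreases the tracked RoPE phase by one position's step.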
+
+// ============================================================
+// Neural Network Operations
+// ============================================================
+
+// RMSNorm kernel: one block per row
+__global__ void rmsnorm_kernel(float* x, const float* w, int dim, float eps) {
+    int row = blockIdx.x;
+    float* rowData = x + row * dim;
+
+    float sum = 0.0f;
+    for (int i = threadIdx.x; i < dim; i += blockDim.x) {
+        float v = rowData[i];
+        sum = fmaf(v, v, sum);
+    }
+
+    // Warp reduce
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        sum += __shfl_down_sync(0xffffffff, sum, offset);
+    }
+
+    __shared__ float warpSum[8];
+    __shared__ float rms;
+
+    int lane = threadIdx.x & 31;
+    int warp = threadIdx.x >> 5;
+    if (lane == 0) {
+        warpSum[warp] = sum;
+    }
+    __syncthreads();
+
+    if (warp == 0) {
+        float v = (lane < 8) ? warpSum[lane] : 0.0f;
+        for (int offset = 16; offset > 0; offset >>= 1) {
+            v += __shfl_down_sync(0xffffffff, v, offset);
+        }
+        if (lane == 0) {
+            rms = rsqrtf(v / dim + eps);
+        }
+    }
+    __syncthreads();
+
+    for (int i = threadIdx.x; i < dim; i += blockDim.x) {
+        rowData[i] = rowData[i] * rms * w[i];
+    }
+}
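+
+// For reference (illustrative), the kernel above computes, per row:
+//     y_i = x_i * w_i / sqrt(mean(x^2) + eps)
+// with one block per row, a per-warp shuffle reduction of the sum of squares, and a second
+// reduction over the (up to 8) warp partials before the normalized write-back.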
+
+__global__ void paged_attention_batch_kernel_f16kv(
+    const float* Q,
+    const unsigned short* const* KBlocksFlat,
+    const unsigned short* const* VBlocksFlat,
+    const int* blockOffsets,
+    const int* kvLens,
+    const int* queryPos,
+    float* out,
+    int numTokens,
+    int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale
+) {
+    const int tok = blockIdx.x;
+    const int head = blockIdx.y;
+    const int lane = threadIdx.x & 31;
+    if (tok >= numTokens) {
+        return;
+    }
+    // One warp per (tok, head)
+    if (threadIdx.x >= 32) {
+        return;
+    }
+
+    const int kvHead = head / (numHeads / numKVHeads);
+    const float* q = Q + tok * numHeads * headDim + head * headDim;
+    float* o = out + tok * numHeads * headDim + head * headDim;
+
+    const int kvLen = kvLens[tok];
+    const int qPos = queryPos[tok];
+    const int base = blockOffsets[tok];
+
+    const int kvStride = numKVHeads * headDim;
+    const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
+
+    // Cache Q in registers (per lane) to avoid reloading it for every KV token.
+    float qreg[8];
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        const int d = lane + 32 * i;
+        qreg[i] = (d < headDim) ? q[d] : 0.0f;
+    }
+
+    // Support headDim up to 256 (<= 8 values per lane)
+    float acc[8];
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        acc[i] = 0.0f;
+    }
+
+    float m = -INFINITY;
+    float l = 0.0f;
+
+    for (int kv = 0; kv < effectiveLen; kv++) {
+        const int bidx = kv / blockSize;
+        const int boff = kv % blockSize;
+
+        const __half* kBlock = reinterpret_cast<const __half*>(KBlocksFlat[base + bidx]);
+        const __half* k = kBlock + boff * kvStride + kvHead * headDim;
+
+        float dot = 0.0f;
+        #pragma unroll
+        for (int i = 0; i < 8; i++) {
+            const int d = lane + 32 * i;
+            if (d < headDim) {
+                dot = fmaf(qreg[i], __half2float(k[d]), dot);
+            }
+        }
+        for (int offset = 16; offset > 0; offset >>= 1) {
+            dot += __shfl_down_sync(0xffffffff, dot, offset);
+        }
+        dot = __shfl_sync(0xffffffff, dot, 0);
+
+        const float score = dot * scale;
+        float alpha = 1.0f;
+        float beta = 0.0f;
+        if (lane == 0) {
+            const float newM = fmaxf(m, score);
+            alpha = __expf(m - newM);
+            beta = __expf(score - newM);
+            m = newM;
+            l = l * alpha + beta;
+        }
+        alpha = __shfl_sync(0xffffffff, alpha, 0);
+        beta = __shfl_sync(0xffffffff, beta, 0);
+        l = __shfl_sync(0xffffffff, l, 0);
+
+        const __half* vBlock = reinterpret_cast<const __half*>(VBlocksFlat[base + bidx]);
+        const __half* v = vBlock + boff * kvStride + kvHead * headDim;
+        #pragma unroll
+        for (int i = 0; i < 8; i++) {
+            const int d = lane + 32 * i;
+            if (d < headDim) {
+                acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
+            }
+        }
+    }
+
+    const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        const int d = lane + 32 * i;
+        if (d < headDim) {
+            o[d] = acc[i] * invL;
+        }
+    }
+}
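+
+// Illustrative recap of the streaming-softmax update used in the KV loop above: with running
+// maximum m and normalizer l, each new score s is folded in as
+//     m' = max(m, s),  alpha = exp(m - m'),  beta = exp(s - m')
+//     l' = l*alpha + beta,  acc' = acc*alpha + beta*v
+// so the final output is acc / l without ever materializing the full score vector.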
+
+// Fused RoPE + paged attention (batch, f16 KV). Expects un-rotated Q/K.
+__global__ void paged_attention_batch_kernel_f16kv_rope(
+    const float* Q,
+    const unsigned short* const* KBlocksFlat,
+    const unsigned short* const* VBlocksFlat,
+    const int* blockOffsets,
+    const int* kvLens,
+    const int* queryPos,
+    float* out,
+    int numTokens,
+    int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale
+) {
+    const int tok = blockIdx.x;
+    const int head = blockIdx.y;
+    const int lane = threadIdx.x & 31;
+    if (tok >= numTokens) {
+        return;
+    }
+    // One warp per (tok, head)
+    if (threadIdx.x >= 32) {
+        return;
+    }
+
+    const int kvHead = head / (numHeads / numKVHeads);
+    const float* q = Q + tok * numHeads * headDim + head * headDim;
+    float* o = out + tok * numHeads * headDim + head * headDim;
+
+    const int kvLen = kvLens[tok];
+    const int qPos = queryPos[tok];
+    const int base = blockOffsets[tok];
+
+    const int kvStride = numKVHeads * headDim;
+    const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
+    const int halfDim = headDim >> 1;
+
+    // Cache Q pairs + per-dim RoPE phase for delta=(qPos - kv) with kv starting at 0.
+    float q0[4];
+    float q1[4];
+    float cosStep[4];
+    float sinStep[4];
+    float cosDelta[4];
+    float sinDelta[4];
+    int pairCount = 0;
+    for (int j = lane; j < halfDim; j += 32) {
+        q0[pairCount] = q[j];
+        q1[pairCount] = q[j + halfDim];
+        cosStep[pairCount] = rope_cos_step_const[j];
+        sinStep[pairCount] = rope_sin_step_const[j];
+        const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
+        const float2 ph = complex_pow_int(baseStep, qPos);
+        cosDelta[pairCount] = ph.x;
+        sinDelta[pairCount] = ph.y;
+        pairCount++;
+    }
+
+    // Support headDim up to 256.
+    float acc[8];
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        acc[i] = 0.0f;
+    }
+
+    float m = -INFINITY;
+    float l = 0.0f;
+
+    for (int kv = 0; kv < effectiveLen; kv++) {
+        const int bidx = kv / blockSize;
+        const int boff = kv % blockSize;
+
+        const __half* kBlock = reinterpret_cast<const __half*>(KBlocksFlat[base + bidx]);
+        const __half* k = kBlock + boff * kvStride + kvHead * headDim;
+
+        float dot = 0.0f;
+        #pragma unroll
+        for (int pi = 0; pi < 4; pi++) {
+            if (pi >= pairCount) {
+                break;
+            }
+            const int j = lane + 32 * pi;
+            if (j >= halfDim) {
+                continue;
+            }
+            const float k0 = __half2float(k[j]);
+            const float k1 = __half2float(k[j + halfDim]);
+            const float a = fmaf(q0[pi], k0, q1[pi] * k1);         // q0*k0 + q1*k1
+            const float b = fmaf(q0[pi], k1, -q1[pi] * k0);        // q0*k1 - q1*k0
+            dot = fmaf(cosDelta[pi], a, dot);
+            dot = fmaf(sinDelta[pi], b, dot);
+        }
+        for (int offset = 16; offset > 0; offset >>= 1) {
+            dot += __shfl_down_sync(0xffffffff, dot, offset);
+        }
+        dot = __shfl_sync(0xffffffff, dot, 0);
+
+        // Advance delta -> delta-1 for next kv.
+        #pragma unroll
+        for (int pi = 0; pi < 4; pi++) {
+            if (pi >= pairCount) {
+                break;
+            }
+            rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
+        }
+
+        const float score = dot * scale;
+        float alpha = 1.0f;
+        float beta = 0.0f;
+        if (lane == 0) {
+            const float newM = fmaxf(m, score);
+            alpha = __expf(m - newM);
+            beta = __expf(score - newM);
+            m = newM;
+            l = l * alpha + beta;
+        }
+        alpha = __shfl_sync(0xffffffff, alpha, 0);
+        beta = __shfl_sync(0xffffffff, beta, 0);
+        l = __shfl_sync(0xffffffff, l, 0);
+
+        const __half* vBlock = reinterpret_cast<const __half*>(VBlocksFlat[base + bidx]);
+        const __half* v = vBlock + boff * kvStride + kvHead * headDim;
+        #pragma unroll
+        for (int i = 0; i < 8; i++) {
+            const int d = lane + 32 * i;
+            if (d < headDim) {
+                acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
+            }
+        }
+    }
+
+    const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        const int d = lane + 32 * i;
+        if (d < headDim) {
+            o[d] = acc[i] * invL;
+        }
+    }
+}
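+
+// Why the fused dot product above is valid (illustrative): for a rotary pair, the dot of the
+// rotated q (position qPos) and rotated k (position kv) depends only on the position delta,
+//     dot_pair = cos(dTheta)*(q0*k0 + q1*k1) + sin(dTheta)*(q0*k1 - q1*k0),
+// with dTheta = (qPos - kv)*invFreq_j, which is exactly the cosDelta/sinDelta form in the
+// inner loop; rope_advance_neg shrinks dTheta by one step as kv advances.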
+
+__global__ void cast_f32_to_f16_kernel(const float* src, __half* dst, int n) {
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i < n) {
+        dst[i] = __float2half_rn(src[i]);
+    }
+}
+
+__global__ void paged_attention_kernel_f16kv(
+    const float* Q,
+    const unsigned short* const* KBlocks,
+    const unsigned short* const* VBlocks,
+    float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale, int startPos
+) {
+    const int seq = blockIdx.x;
+    const int head = blockIdx.y;
+    const int lane = threadIdx.x & 31;
+    if (seq >= seqLen) {
+        return;
+    }
+    // One warp per (seq, head)
+    if (threadIdx.x >= 32) {
+        return;
+    }
+
+    const int kvHead = head / (numHeads / numKVHeads);
+    const float* q = Q + seq * numHeads * headDim + head * headDim;
+    float* o = out + seq * numHeads * headDim + head * headDim;
+
+    const int kvStride = numKVHeads * headDim;
+    const int queryPos = startPos + seq;
+    const int effectiveLen = (kvLen < (queryPos + 1)) ? kvLen : (queryPos + 1);
+
+    // Cache Q in registers (per lane) to avoid reloading it for every KV token.
+    float qreg[8];
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        const int d = lane + 32 * i;
+        qreg[i] = (d < headDim) ? q[d] : 0.0f;
+    }
+
+    float acc[8];
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        acc[i] = 0.0f;
+    }
+    float m = -INFINITY;
+    float l = 0.0f;
+
+    for (int kv = 0; kv < effectiveLen; kv++) {
+        const int blockIdxKV = kv / blockSize;
+        const int blockOff = kv % blockSize;
+
+        const __half* kBlock = reinterpret_cast<const __half*>(KBlocks[blockIdxKV]);
+        const __half* k = kBlock + blockOff * kvStride + kvHead * headDim;
+
+        float dot = 0.0f;
+        #pragma unroll
+        for (int i = 0; i < 8; i++) {
+            const int d = lane + 32 * i;
+            if (d < headDim) {
+                dot = fmaf(qreg[i], __half2float(k[d]), dot);
+            }
+        }
+        for (int offset = 16; offset > 0; offset >>= 1) {
+            dot += __shfl_down_sync(0xffffffff, dot, offset);
+        }
+        dot = __shfl_sync(0xffffffff, dot, 0);
+
+        const float score = dot * scale;
+        float alpha = 1.0f;
+        float beta = 0.0f;
+        if (lane == 0) {
+            const float newM = fmaxf(m, score);
+            alpha = __expf(m - newM);
+            beta = __expf(score - newM);
+            m = newM;
+            l = l * alpha + beta;
+        }
+        alpha = __shfl_sync(0xffffffff, alpha, 0);
+        beta = __shfl_sync(0xffffffff, beta, 0);
+        l = __shfl_sync(0xffffffff, l, 0);
+
+        const __half* vBlock = reinterpret_cast<const __half*>(VBlocks[blockIdxKV]);
+        const __half* v = vBlock + blockOff * kvStride + kvHead * headDim;
+        #pragma unroll
+        for (int i = 0; i < 8; i++) {
+            const int d = lane + 32 * i;
+            if (d < headDim) {
+                acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
+            }
+        }
+    }
+
+    const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        const int d = lane + 32 * i;
+        if (d < headDim) {
+            o[d] = acc[i] * invL;
+        }
+    }
+}
+
+// Fused RoPE + paged attention (single, f16 KV). Expects un-rotated Q/K.
+__global__ void paged_attention_kernel_f16kv_rope(
+    const float* Q,
+    const unsigned short* const* KBlocks,
+    const unsigned short* const* VBlocks,
+    float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale, int startPos
+) {
+    const int seq = blockIdx.x;
+    const int head = blockIdx.y;
+    const int lane = threadIdx.x & 31;
+    if (seq >= seqLen) {
+        return;
+    }
+    // One warp per (seq, head)
+    if (threadIdx.x >= 32) {
+        return;
+    }
+
+    const int kvHead = head / (numHeads / numKVHeads);
+    const float* q = Q + seq * numHeads * headDim + head * headDim;
+    float* o = out + seq * numHeads * headDim + head * headDim;
+
+    const int kvStride = numKVHeads * headDim;
+    const int queryPos = startPos + seq;
+    const int effectiveLen = (kvLen < (queryPos + 1)) ? kvLen : (queryPos + 1);
+    const int halfDim = headDim >> 1;
+
+    float q0[4];
+    float q1[4];
+    float cosStep[4];
+    float sinStep[4];
+    float cosDelta[4];
+    float sinDelta[4];
+    int pairCount = 0;
+    for (int j = lane; j < halfDim; j += 32) {
+        q0[pairCount] = q[j];
+        q1[pairCount] = q[j + halfDim];
+        cosStep[pairCount] = rope_cos_step_const[j];
+        sinStep[pairCount] = rope_sin_step_const[j];
+        const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
+        const float2 ph = complex_pow_int(baseStep, queryPos);
+        cosDelta[pairCount] = ph.x;
+        sinDelta[pairCount] = ph.y;
+        pairCount++;
+    }
+
+    float acc[8];
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        acc[i] = 0.0f;
+    }
+    float m = -INFINITY;
+    float l = 0.0f;
+
+    for (int kv = 0; kv < effectiveLen; kv++) {
+        const int blockIdxKV = kv / blockSize;
+        const int blockOff = kv % blockSize;
+
+        const __half* kBlock = reinterpret_cast<const __half*>(KBlocks[blockIdxKV]);
+        const __half* k = kBlock + blockOff * kvStride + kvHead * headDim;
+
+        float dot = 0.0f;
+        #pragma unroll
+        for (int pi = 0; pi < 4; pi++) {
+            if (pi >= pairCount) {
+                break;
+            }
+            const int j = lane + 32 * pi;
+            if (j >= halfDim) {
+                continue;
+            }
+            const float k0 = __half2float(k[j]);
+            const float k1 = __half2float(k[j + halfDim]);
+            const float a = fmaf(q0[pi], k0, q1[pi] * k1);
+            const float b = fmaf(q0[pi], k1, -q1[pi] * k0);
+            dot = fmaf(cosDelta[pi], a, dot);
+            dot = fmaf(sinDelta[pi], b, dot);
+        }
+        for (int offset = 16; offset > 0; offset >>= 1) {
+            dot += __shfl_down_sync(0xffffffff, dot, offset);
+        }
+        dot = __shfl_sync(0xffffffff, dot, 0);
+
+        // Advance delta -> delta-1 for next kv.
+        #pragma unroll
+        for (int pi = 0; pi < 4; pi++) {
+            if (pi >= pairCount) {
+                break;
+            }
+            rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
+        }
+
+        const float score = dot * scale;
+        float alpha = 1.0f;
+        float beta = 0.0f;
+        if (lane == 0) {
+            const float newM = fmaxf(m, score);
+            alpha = __expf(m - newM);
+            beta = __expf(score - newM);
+            m = newM;
+            l = l * alpha + beta;
+        }
+        alpha = __shfl_sync(0xffffffff, alpha, 0);
+        beta = __shfl_sync(0xffffffff, beta, 0);
+        l = __shfl_sync(0xffffffff, l, 0);
+
+        const __half* vBlock = reinterpret_cast<const __half*>(VBlocks[blockIdxKV]);
+        const __half* v = vBlock + blockOff * kvStride + kvHead * headDim;
+        #pragma unroll
+        for (int i = 0; i < 8; i++) {
+            const int d = lane + 32 * i;
+            if (d < headDim) {
+                acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
+            }
+        }
+    }
+
+    const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        const int d = lane + 32 * i;
+        if (d < headDim) {
+            o[d] = acc[i] * invL;
+        }
+    }
+}
+
+template <typename T>
+__device__ __forceinline__ float load_kv(const T* p, int idx) {
+    return p[idx];
+}
+
+template <>
+__device__ __forceinline__ float load_kv<__half>(const __half* p, int idx) {
+    return __half2float(p[idx]);
+}
+
+template <typename T, bool kUseRoPE = false>
+__global__ void paged_attention_split_kv_kernel(
+    const float* Q,
+    const T* const* KBlocks,
+    const T* const* VBlocks,
+    float* partialMax,
+    float* partialSum,
+    float* partialOut,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale, int startPos,
+    int numSplits, int splitSize
+) {
+    const int seq = blockIdx.x;
+    const int head = blockIdx.y;
+    const int split = blockIdx.z;
+    const int lane = threadIdx.x & 31;
+    if (seq >= seqLen) {
+        return;
+    }
+    // One warp per (seq, head, split).
+    if (threadIdx.x >= 32) {
+        return;
+    }
+
+    const int kvHead = head / (numHeads / numKVHeads);
+    const float* q = Q + seq * numHeads * headDim + head * headDim;
+
+    const int kvStride = numKVHeads * headDim;
+    const int queryPos = startPos + seq;
+    const int effectiveLen = (kvLen < (queryPos + 1)) ? kvLen : (queryPos + 1);
+
+    const int splitStart = split * splitSize;
+    const int splitEnd = (splitStart + splitSize < effectiveLen) ? (splitStart + splitSize) : effectiveLen;
+
+    const size_t qh = (size_t)seq * (size_t)numHeads + (size_t)head;
+    const size_t splitIdx = qh * (size_t)numSplits + (size_t)split;
+    float* outVec = partialOut + splitIdx * (size_t)headDim;
+
+    if (splitStart >= splitEnd) {
+        if (lane == 0) {
+            partialMax[splitIdx] = -INFINITY;
+            partialSum[splitIdx] = 0.0f;
+        }
+        #pragma unroll
+        for (int i = 0; i < 8; i++) {
+            const int d = lane + 32 * i;
+            if (d < headDim) {
+                outVec[d] = 0.0f;
+            }
+        }
+        return;
+    }
+
+    const int halfDim = headDim >> 1;
+    const int ropeExp = queryPos - splitStart;
+
+    float qreg[8];
+    float q0[4];
+    float q1[4];
+    float cosStep[4];
+    float sinStep[4];
+    float cosDelta[4];
+    float sinDelta[4];
+    int pairCount = 0;
+    if (kUseRoPE) {
+        for (int j = lane; j < halfDim; j += 32) {
+            q0[pairCount] = q[j];
+            q1[pairCount] = q[j + halfDim];
+            cosStep[pairCount] = rope_cos_step_const[j];
+            sinStep[pairCount] = rope_sin_step_const[j];
+            const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
+            const float2 ph = complex_pow_int(baseStep, ropeExp);
+            cosDelta[pairCount] = ph.x;
+            sinDelta[pairCount] = ph.y;
+            pairCount++;
+        }
+    } else {
+        #pragma unroll
+        for (int i = 0; i < 8; i++) {
+            const int d = lane + 32 * i;
+            qreg[i] = (d < headDim) ? q[d] : 0.0f;
+        }
+    }
+
+    float acc[8];
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        acc[i] = 0.0f;
+    }
+    float m = -INFINITY;
+    float l = 0.0f;
+
+    for (int kv = splitStart; kv < splitEnd; kv++) {
+        const int blockIdxKV = kv / blockSize;
+        const int blockOff = kv % blockSize;
+
+        const T* kBlock = KBlocks[blockIdxKV];
+        const T* k = kBlock + blockOff * kvStride + kvHead * headDim;
+
+        float dot = 0.0f;
+        if (kUseRoPE) {
+            #pragma unroll
+            for (int pi = 0; pi < 4; pi++) {
+                if (pi >= pairCount) {
+                    break;
+                }
+                const int j = lane + 32 * pi;
+                if (j >= halfDim) {
+                    continue;
+                }
+                const float k0 = load_kv<T>(k, j);
+                const float k1 = load_kv<T>(k, j + halfDim);
+                const float a = fmaf(q0[pi], k0, q1[pi] * k1);
+                const float b = fmaf(q0[pi], k1, -q1[pi] * k0);
+                dot = fmaf(cosDelta[pi], a, dot);
+                dot = fmaf(sinDelta[pi], b, dot);
+            }
+        } else {
+            #pragma unroll
+            for (int i = 0; i < 8; i++) {
+                const int d = lane + 32 * i;
+                if (d < headDim) {
+                    dot = fmaf(qreg[i], load_kv<T>(k, d), dot);
+                }
+            }
+        }
+        for (int offset = 16; offset > 0; offset >>= 1) {
+            dot += __shfl_down_sync(0xffffffff, dot, offset);
+        }
+        dot = __shfl_sync(0xffffffff, dot, 0);
+
+        if (kUseRoPE) {
+            #pragma unroll
+            for (int pi = 0; pi < 4; pi++) {
+                if (pi >= pairCount) {
+                    break;
+                }
+                rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
+            }
+        }
+
+        const float score = dot * scale;
+        float alpha = 1.0f;
+        float beta = 0.0f;
+        if (lane == 0) {
+            const float newM = fmaxf(m, score);
+            alpha = __expf(m - newM);
+            beta = __expf(score - newM);
+            m = newM;
+            l = l * alpha + beta;
+        }
+        alpha = __shfl_sync(0xffffffff, alpha, 0);
+        beta = __shfl_sync(0xffffffff, beta, 0);
+
+        const T* vBlock = VBlocks[blockIdxKV];
+        const T* v = vBlock + blockOff * kvStride + kvHead * headDim;
+        #pragma unroll
+        for (int i = 0; i < 8; i++) {
+            const int d = lane + 32 * i;
+            if (d < headDim) {
+                acc[i] = acc[i] * alpha + beta * load_kv<T>(v, d);
+            }
+        }
+    }
+
+    if (lane == 0) {
+        partialMax[splitIdx] = m;
+        partialSum[splitIdx] = l;
+    }
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        const int d = lane + 32 * i;
+        if (d < headDim) {
+            outVec[d] = acc[i];
+        }
+    }
+}
+
+template <typename T, bool kUseRoPE = false>
+__global__ void paged_attention_split_kv_batch_kernel(
+    const float* Q,
+    const T* const* KBlocksFlat,
+    const T* const* VBlocksFlat,
+    const int* blockOffsets,
+    const int* kvLens,
+    const int* queryPos,
+    float* partialMax,
+    float* partialSum,
+    float* partialOut,
+    int numTokens,
+    int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale,
+    int numSplits, int splitSize
+) {
+    const int tok = blockIdx.x;
+    const int head = blockIdx.y;
+    const int split = blockIdx.z;
+    const int lane = threadIdx.x & 31;
+    if (tok >= numTokens) {
+        return;
+    }
+    // One warp per (tok, head, split).
+    if (threadIdx.x >= 32) {
+        return;
+    }
+
+    const int kvHead = head / (numHeads / numKVHeads);
+    const float* q = Q + tok * numHeads * headDim + head * headDim;
+
+    const int kvLen = kvLens[tok];
+    const int qPos = queryPos[tok];
+    const int base = blockOffsets[tok];
+
+    const int kvStride = numKVHeads * headDim;
+    const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
+
+    const int splitStart = split * splitSize;
+    const int splitEnd = (splitStart + splitSize < effectiveLen) ? (splitStart + splitSize) : effectiveLen;
+
+    const size_t qh = (size_t)tok * (size_t)numHeads + (size_t)head;
+    const size_t splitIdx = qh * (size_t)numSplits + (size_t)split;
+    float* outVec = partialOut + splitIdx * (size_t)headDim;
+
+    if (splitStart >= splitEnd) {
+        if (lane == 0) {
+            partialMax[splitIdx] = -INFINITY;
+            partialSum[splitIdx] = 0.0f;
+        }
+        #pragma unroll
+        for (int i = 0; i < 8; i++) {
+            const int d = lane + 32 * i;
+            if (d < headDim) {
+                outVec[d] = 0.0f;
+            }
+        }
+        return;
+    }
+
+    const int halfDim = headDim >> 1;
+    const int ropeExp = qPos - splitStart;
+
+    float qreg[8];
+    float q0[4];
+    float q1[4];
+    float cosStep[4];
+    float sinStep[4];
+    float cosDelta[4];
+    float sinDelta[4];
+    int pairCount = 0;
+    if (kUseRoPE) {
+        for (int j = lane; j < halfDim; j += 32) {
+            q0[pairCount] = q[j];
+            q1[pairCount] = q[j + halfDim];
+            cosStep[pairCount] = rope_cos_step_const[j];
+            sinStep[pairCount] = rope_sin_step_const[j];
+            const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
+            const float2 ph = complex_pow_int(baseStep, ropeExp);
+            cosDelta[pairCount] = ph.x;
+            sinDelta[pairCount] = ph.y;
+            pairCount++;
+        }
+    } else {
+        #pragma unroll
+        for (int i = 0; i < 8; i++) {
+            const int d = lane + 32 * i;
+            qreg[i] = (d < headDim) ? q[d] : 0.0f;
+        }
+    }
+
+    float acc[8];
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        acc[i] = 0.0f;
+    }
+    float m = -INFINITY;
+    float l = 0.0f;
+
+    for (int kv = splitStart; kv < splitEnd; kv++) {
+        const int bidx = kv / blockSize;
+        const int boff = kv % blockSize;
+
+        const T* kBlock = KBlocksFlat[base + bidx];
+        const T* k = kBlock + boff * kvStride + kvHead * headDim;
+
+        float dot = 0.0f;
+        if (kUseRoPE) {
+            #pragma unroll
+            for (int pi = 0; pi < 4; pi++) {
+                if (pi >= pairCount) {
+                    break;
+                }
+                const int j = lane + 32 * pi;
+                if (j >= halfDim) {
+                    continue;
+                }
+                const float k0 = load_kv<T>(k, j);
+                const float k1 = load_kv<T>(k, j + halfDim);
+                const float a = fmaf(q0[pi], k0, q1[pi] * k1);
+                const float b = fmaf(q0[pi], k1, -q1[pi] * k0);
+                dot = fmaf(cosDelta[pi], a, dot);
+                dot = fmaf(sinDelta[pi], b, dot);
+            }
+        } else {
+            #pragma unroll
+            for (int i = 0; i < 8; i++) {
+                const int d = lane + 32 * i;
+                if (d < headDim) {
+                    dot = fmaf(qreg[i], load_kv<T>(k, d), dot);
+                }
+            }
+        }
+        for (int offset = 16; offset > 0; offset >>= 1) {
+            dot += __shfl_down_sync(0xffffffff, dot, offset);
+        }
+        dot = __shfl_sync(0xffffffff, dot, 0);
+
+        if (kUseRoPE) {
+            #pragma unroll
+            for (int pi = 0; pi < 4; pi++) {
+                if (pi >= pairCount) {
+                    break;
+                }
+                rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
+            }
+        }
+
+        const float score = dot * scale;
+        float alpha = 1.0f;
+        float beta = 0.0f;
+        if (lane == 0) {
+            const float newM = fmaxf(m, score);
+            alpha = __expf(m - newM);
+            beta = __expf(score - newM);
+            m = newM;
+            l = l * alpha + beta;
+        }
+        alpha = __shfl_sync(0xffffffff, alpha, 0);
+        beta = __shfl_sync(0xffffffff, beta, 0);
+
+        const T* vBlock = VBlocksFlat[base + bidx];
+        const T* v = vBlock + boff * kvStride + kvHead * headDim;
+        #pragma unroll
+        for (int i = 0; i < 8; i++) {
+            const int d = lane + 32 * i;
+            if (d < headDim) {
+                acc[i] = acc[i] * alpha + beta * load_kv<T>(v, d);
+            }
+        }
+    }
+
+    if (lane == 0) {
+        partialMax[splitIdx] = m;
+        partialSum[splitIdx] = l;
+    }
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        const int d = lane + 32 * i;
+        if (d < headDim) {
+            outVec[d] = acc[i];
+        }
+    }
+}
+
+__global__ void paged_attention_split_kv_reduce_kernel(
+    const float* partialMax,
+    const float* partialSum,
+    const float* partialOut,
+    float* out,
+    int queryCount,
+    int numHeads,
+    int headDim,
+    int numSplits
+) {
+    const int q = blockIdx.x;
+    const int head = blockIdx.y;
+    const int lane = threadIdx.x & 31;
+    if (q >= queryCount) {
+        return;
+    }
+    // One warp per (q, head).
+    if (threadIdx.x >= 32) {
+        return;
+    }
+
+    const size_t qh = (size_t)q * (size_t)numHeads + (size_t)head;
+    const size_t base = qh * (size_t)numSplits;
+
+    float gmax = -INFINITY;
+    if (lane == 0) {
+        for (int s = 0; s < numSplits; s++) {
+            gmax = fmaxf(gmax, partialMax[base + (size_t)s]);
+        }
+    }
+    gmax = __shfl_sync(0xffffffff, gmax, 0);
+    if (gmax == -INFINITY) {
+        const size_t outBase = qh * (size_t)headDim;
+        #pragma unroll
+        for (int i = 0; i < 8; i++) {
+            const int d = lane + 32 * i;
+            if (d < headDim) {
+                out[outBase + (size_t)d] = 0.0f;
+            }
+        }
+        return;
+    }
+
+    float sum = 0.0f;
+    if (lane == 0) {
+        for (int s = 0; s < numSplits; s++) {
+            const float l = partialSum[base + (size_t)s];
+            sum += l * __expf(partialMax[base + (size_t)s] - gmax);
+        }
+    }
+    sum = __shfl_sync(0xffffffff, sum, 0);
+    const float invSum = (sum > 0.0f) ? (1.0f / sum) : 0.0f;
+
+    float acc[8];
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        acc[i] = 0.0f;
+    }
+
+    for (int s = 0; s < numSplits; s++) {
+        float scale = 0.0f;
+        if (lane == 0) {
+            scale = __expf(partialMax[base + (size_t)s] - gmax);
+        }
+        scale = __shfl_sync(0xffffffff, scale, 0);
+
+        const float* vec = partialOut + (base + (size_t)s) * (size_t)headDim;
+        #pragma unroll
+        for (int i = 0; i < 8; i++) {
+            const int d = lane + 32 * i;
+            if (d < headDim) {
+                acc[i] = fmaf(scale, vec[d], acc[i]);
+            }
+        }
+    }
+
+    const size_t outBase = qh * (size_t)headDim;
+    #pragma unroll
+    for (int i = 0; i < 8; i++) {
+        const int d = lane + 32 * i;
+        if (d < headDim) {
+            out[outBase + (size_t)d] = acc[i] * invSum;
+        }
+    }
+}
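+
+// Illustrative recap of the split merge above: each split s contributes a partial maximum
+// m_s, partial normalizer l_s and an unnormalized accumulator o_s. With M = max_s m_s,
+//     L = sum_s l_s * exp(m_s - M),   out = ( sum_s exp(m_s - M) * o_s ) / L
+// which reproduces the single-pass softmax result regardless of how the KV range was split.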
+
+int cuda_rmsnorm_f32(float* x, const float* w, int seqLen, int dim, float eps) {
+    int threads = 256;
+    rmsnorm_kernel<<<seqLen, threads>>>(x, w, dim, eps);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+int cuda_cast_f32_to_f16(const float* src, unsigned short* dst, int n) {
+    int threads = 256;
+    int blocks = (n + threads - 1) / threads;
+    cast_f32_to_f16_kernel<<<blocks, threads>>>(src, reinterpret_cast<__half*>(dst), n);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+int cuda_paged_attention_f32_f16kv(
+    const float* Q,
+    const unsigned short* const* KBlocks,
+    const unsigned short* const* VBlocks,
+    float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale, int startPos
+) {
+    // Split-KV Flash Decoding for long contexts.
+    const int maxEffectiveLen = (kvLen < (startPos + seqLen)) ? kvLen : (startPos + seqLen);
+    const int numSplits = (maxEffectiveLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
+    const int qhCount = seqLen * numHeads;
+    const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
+
+    if (!useSplit) {
+        dim3 blocks(seqLen, numHeads);
+        int threads = 32;
+        paged_attention_kernel_f16kv<<<blocks, threads>>>(
+            Q, KBlocks, VBlocks, out,
+            seqLen, kvLen, numHeads, numKVHeads, headDim,
+            blockSize, scale, startPos
+        );
+        CHECK_CUDA(cudaGetLastError());
+        return 0;
+    }
+
+    const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
+    const size_t totalFloats = splitCount * (size_t)(headDim + 2);
+    float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
+    if (buf == nullptr) {
+        return 1;
+    }
+    float* partialMax = buf;
+    float* partialSum = partialMax + splitCount;
+    float* partialOut = partialSum + splitCount;
+
+    dim3 blocks1(seqLen, numHeads, numSplits);
+    paged_attention_split_kv_kernel<__half><<<blocks1, 32>>>(
+        Q,
+        reinterpret_cast<const __half* const*>(KBlocks),
+        reinterpret_cast<const __half* const*>(VBlocks),
+        partialMax, partialSum, partialOut,
+        seqLen, kvLen, numHeads, numKVHeads, headDim,
+        blockSize,
+        scale, startPos,
+        numSplits, kPagedAttentionSplitSize
+    );
+    CHECK_CUDA(cudaGetLastError());
+
+    dim3 blocks2(seqLen, numHeads);
+    paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
+        partialMax, partialSum, partialOut,
+        out,
+        seqLen, numHeads, headDim, numSplits
+    );
+    CHECK_CUDA(cudaGetLastError());
+
+    cuda_free(buf);
+    return 0;
+}
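+
+// Scratch layout used by the split path above (illustrative): one cuda_malloc of
+// qhCount*numSplits*(headDim+2) floats is carved into [partialMax | partialSum | partialOut],
+// with partialOut holding one headDim-sized vector per (query, head, split). The split path
+// is only taken for long contexts (more than one 1024-token split) when there are few enough
+// (query, head) pairs that the extra blockIdx.z parallelism actually helps.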
+
+int cuda_paged_attention_rope_f32_f16kv(
+    const float* Q,
+    const unsigned short* const* KBlocks,
+    const unsigned short* const* VBlocks,
+    float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale, int startPos,
+    float theta
+) {
+    if (ensure_rope_step_table(headDim, theta) != 0) {
+        return 1;
+    }
+
+    // Split-KV Flash Decoding for long contexts.
+    const int maxEffectiveLen = (kvLen < (startPos + seqLen)) ? kvLen : (startPos + seqLen);
+    const int numSplits = (maxEffectiveLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
+    const int qhCount = seqLen * numHeads;
+    const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
+
+    if (!useSplit) {
+        dim3 blocks(seqLen, numHeads);
+        int threads = 32;
+        paged_attention_kernel_f16kv_rope<<<blocks, threads>>>(
+            Q, KBlocks, VBlocks, out,
+            seqLen, kvLen, numHeads, numKVHeads, headDim,
+            blockSize, scale, startPos
+        );
+        CHECK_CUDA(cudaGetLastError());
+        return 0;
+    }
+
+    const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
+    const size_t totalFloats = splitCount * (size_t)(headDim + 2);
+    float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
+    if (buf == nullptr) {
+        return 1;
+    }
+    float* partialMax = buf;
+    float* partialSum = partialMax + splitCount;
+    float* partialOut = partialSum + splitCount;
+
+    dim3 blocks1(seqLen, numHeads, numSplits);
+    paged_attention_split_kv_kernel<__half, true><<<blocks1, 32>>>(
+        Q,
+        reinterpret_cast<const __half* const*>(KBlocks),
+        reinterpret_cast<const __half* const*>(VBlocks),
+        partialMax, partialSum, partialOut,
+        seqLen, kvLen, numHeads, numKVHeads, headDim,
+        blockSize,
+        scale, startPos,
+        numSplits, kPagedAttentionSplitSize
+    );
+    CHECK_CUDA(cudaGetLastError());
+
+    dim3 blocks2(seqLen, numHeads);
+    paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
+        partialMax, partialSum, partialOut,
+        out,
+        seqLen, numHeads, headDim, numSplits
+    );
+    CHECK_CUDA(cudaGetLastError());
+
+    cuda_free(buf);
+    return 0;
+}
+
+int cuda_paged_attention_batch_f32_f16kv(
+    const float* Q,
+    const unsigned short* const* KBlocksFlat,
+    const unsigned short* const* VBlocksFlat,
+    const int* blockOffsets,
+    const int* kvLens,
+    const int* queryPos,
+    float* out,
+    int numTokens,
+    int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale,
+    int maxKvLen
+) {
+    const int numSplits = (maxKvLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
+    const int qhCount = numTokens * numHeads;
+    const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
+
+    if (!useSplit) {
+        dim3 blocks(numTokens, numHeads);
+        int threads = 32;
+        paged_attention_batch_kernel_f16kv<<<blocks, threads>>>(
+            Q, KBlocksFlat, VBlocksFlat,
+            blockOffsets, kvLens, queryPos,
+            out,
+            numTokens,
+            numHeads, numKVHeads, headDim,
+            blockSize,
+            scale
+        );
+        CHECK_CUDA(cudaGetLastError());
+        return 0;
+    }
+
+    const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
+    const size_t totalFloats = splitCount * (size_t)(headDim + 2);
+    float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
+    if (buf == nullptr) {
+        return 1;
+    }
+    float* partialMax = buf;
+    float* partialSum = partialMax + splitCount;
+    float* partialOut = partialSum + splitCount;
+
+    dim3 blocks1(numTokens, numHeads, numSplits);
+    paged_attention_split_kv_batch_kernel<__half><<<blocks1, 32>>>(
+        Q,
+        reinterpret_cast<const __half* const*>(KBlocksFlat),
+        reinterpret_cast<const __half* const*>(VBlocksFlat),
+        blockOffsets, kvLens, queryPos,
+        partialMax, partialSum, partialOut,
+        numTokens,
+        numHeads, numKVHeads, headDim,
+        blockSize,
+        scale,
+        numSplits, kPagedAttentionSplitSize
+    );
+    CHECK_CUDA(cudaGetLastError());
+
+    dim3 blocks2(numTokens, numHeads);
+    paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
+        partialMax, partialSum, partialOut,
+        out,
+        numTokens, numHeads, headDim, numSplits
+    );
+    CHECK_CUDA(cudaGetLastError());
+
+    cuda_free(buf);
+    return 0;
+}
+
+int cuda_paged_attention_rope_batch_f32_f16kv(
+    const float* Q,
+    const unsigned short* const* KBlocksFlat,
+    const unsigned short* const* VBlocksFlat,
+    const int* blockOffsets,
+    const int* kvLens,
+    const int* queryPos,
+    float* out,
+    int numTokens,
+    int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale,
+    int maxKvLen,
+    float theta
+) {
+    if (ensure_rope_step_table(headDim, theta) != 0) {
+        return 1;
+    }
+
+    const int numSplits = (maxKvLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
+    const int qhCount = numTokens * numHeads;
+    const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
+
+    if (!useSplit) {
+        dim3 blocks(numTokens, numHeads);
+        int threads = 32;
+        paged_attention_batch_kernel_f16kv_rope<<<blocks, threads>>>(
+            Q, KBlocksFlat, VBlocksFlat,
+            blockOffsets, kvLens, queryPos,
+            out,
+            numTokens,
+            numHeads, numKVHeads, headDim,
+            blockSize,
+            scale
+        );
+        CHECK_CUDA(cudaGetLastError());
+        return 0;
+    }
+
+    const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
+    const size_t totalFloats = splitCount * (size_t)(headDim + 2);
+    float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
+    if (buf == nullptr) {
+        return 1;
+    }
+    float* partialMax = buf;
+    float* partialSum = partialMax + splitCount;
+    float* partialOut = partialSum + splitCount;
+
+    dim3 blocks1(numTokens, numHeads, numSplits);
+    paged_attention_split_kv_batch_kernel<__half, true><<<blocks1, 32>>>(
+        Q,
+        reinterpret_cast<const __half* const*>(KBlocksFlat),
+        reinterpret_cast<const __half* const*>(VBlocksFlat),
+        blockOffsets, kvLens, queryPos,
+        partialMax, partialSum, partialOut,
+        numTokens,
+        numHeads, numKVHeads, headDim,
+        blockSize,
+        scale,
+        numSplits, kPagedAttentionSplitSize
+    );
+    CHECK_CUDA(cudaGetLastError());
+
+    dim3 blocks2(numTokens, numHeads);
+    paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
+        partialMax, partialSum, partialOut,
+        out,
+        numTokens, numHeads, headDim, numSplits
+    );
+    CHECK_CUDA(cudaGetLastError());
+
+    cuda_free(buf);
+    return 0;
+}
+
+__global__ void paged_attention_batch_kernel(
+    const float* Q,
+    const float* const* KBlocksFlat,
+    const float* const* VBlocksFlat,
+    const int* blockOffsets,
+    const int* kvLens,
+    const int* queryPos,
+    float* out,
+    int numTokens,
+    int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale
+);
+
+__global__ void paged_attention_batch_kernel_f16kv(
+    const float* Q,
+    const unsigned short* const* KBlocksFlat,
+    const unsigned short* const* VBlocksFlat,
+    const int* blockOffsets,
+    const int* kvLens,
+    const int* queryPos,
+    float* out,
+    int numTokens,
+    int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale
+);
+
+int cuda_paged_attention_batch_f32(
+    const float* Q,
+    const float* const* KBlocksFlat,
+    const float* const* VBlocksFlat,
+    const int* blockOffsets,
+    const int* kvLens,
+    const int* queryPos,
+    float* out,
+    int numTokens,
+    int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale,
+    int maxKvLen
+) {
+    const int numSplits = (maxKvLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
+    const int qhCount = numTokens * numHeads;
+    const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
+
+    if (!useSplit) {
+        dim3 blocks(numTokens, numHeads);
+        int threads = 256;
+        paged_attention_batch_kernel<<<blocks, threads>>>(
+            Q, KBlocksFlat, VBlocksFlat,
+            blockOffsets, kvLens, queryPos,
+            out,
+            numTokens,
+            numHeads, numKVHeads, headDim,
+            blockSize,
+            scale
+        );
+        CHECK_CUDA(cudaGetLastError());
+        return 0;
+    }
+
+    const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
+    const size_t totalFloats = splitCount * (size_t)(headDim + 2);
+    float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
+    if (buf == nullptr) {
+        return 1;
+    }
+    float* partialMax = buf;
+    float* partialSum = partialMax + splitCount;
+    float* partialOut = partialSum + splitCount;
+
+    dim3 blocks1(numTokens, numHeads, numSplits);
+    paged_attention_split_kv_batch_kernel<float><<<blocks1, 32>>>(
+        Q,
+        KBlocksFlat,
+        VBlocksFlat,
+        blockOffsets, kvLens, queryPos,
+        partialMax, partialSum, partialOut,
+        numTokens,
+        numHeads, numKVHeads, headDim,
+        blockSize,
+        scale,
+        numSplits, kPagedAttentionSplitSize
+    );
+    CHECK_CUDA(cudaGetLastError());
+
+    dim3 blocks2(numTokens, numHeads);
+    paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
+        partialMax, partialSum, partialOut,
+        out,
+        numTokens, numHeads, headDim, numSplits
+    );
+    CHECK_CUDA(cudaGetLastError());
+
+    cuda_free(buf);
+    return 0;
+}
+
+// RoPE kernel
+__global__ void rope_kernel(float* x, const int* positions, int totalDim, int headDim, float theta) {
+    int seq = blockIdx.x;
+    int pos = positions[seq];
+    float* rowData = x + seq * totalDim;
+    int halfDim = headDim / 2;
+    
+    // Each thread handles one (j, j+halfDim) pair across all heads.
+    // Compute sin/cos once per j and reuse across heads.
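+    // For each pair (j, j + halfDim) the applied rotation is:
+    //   x[j]         <- x[j] * cos(f) - x[j+halfDim] * sin(f)
+    //   x[j+halfDim] <- x[j+halfDim] * cos(f) + x[j] * sin(f)
+    // where f = pos / theta^(2j / headDim).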
+    for (int j = threadIdx.x; j < halfDim; j += blockDim.x) {
+        const float invFreq = 1.0f / powf(theta, 2.0f * j / headDim);
+        const float freq = pos * invFreq;
+        float sinF, cosF;
+        sincosf(freq, &sinF, &cosF);
+
+        for (int headStart = 0; headStart < totalDim; headStart += headDim) {
+            const int idx0 = headStart + j;
+            const int idx1 = idx0 + halfDim;
+
+            const float v0 = rowData[idx0];
+            const float v1 = rowData[idx1];
+
+            rowData[idx0] = v0 * cosF - v1 * sinF;
+            rowData[idx1] = v1 * cosF + v0 * sinF;
+        }
+    }
+}
+
+int cuda_rope_f32(float* x, const int* positions, int seqLen, int numHeads, int headDim, float theta) {
+    int threads = (headDim / 2) < 128 ? (headDim / 2) : 128;
+    rope_kernel<<<seqLen, threads>>>(x, positions, numHeads * headDim, headDim, theta);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// Optimized RoPE kernel for single position (scalar pos)
+__global__ void rope_kernel_single(float* x, int pos, int totalDim, int headDim, float theta) {
+    // seq is always 0 for single token
+    float* rowData = x; // x + 0 * totalDim
+    int halfDim = headDim / 2;
+    
+    // Each thread handles one (j, j+halfDim) pair across all heads.
+    // Compute sin/cos once per j and reuse across heads.
+    for (int j = threadIdx.x; j < halfDim; j += blockDim.x) {
+        const float invFreq = 1.0f / powf(theta, 2.0f * j / headDim);
+        const float freq = pos * invFreq;
+        float sinF, cosF;
+        sincosf(freq, &sinF, &cosF);
+
+        for (int headStart = 0; headStart < totalDim; headStart += headDim) {
+            const int idx0 = headStart + j;
+            const int idx1 = idx0 + halfDim;
+
+            const float v0 = rowData[idx0];
+            const float v1 = rowData[idx1];
+
+            rowData[idx0] = v0 * cosF - v1 * sinF;
+            rowData[idx1] = v1 * cosF + v0 * sinF;
+        }
+    }
+}
+
+int cuda_rope_f32_single(float* x, int pos, int numHeads, int headDim, float theta) {
+    int threads = (headDim / 2) < 128 ? (headDim / 2) : 128;
+    rope_kernel_single<<<1, threads>>>(x, pos, numHeads * headDim, headDim, theta);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// Softmax kernel: one block per row
+__global__ void softmax_kernel(float* x, int cols) {
+    int row = blockIdx.x;
+    float* rowData = x + row * cols;
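+    // Numerically stable softmax: subtract the row max before exponentiating.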
+    
+    __shared__ float smax[256];
+    __shared__ float ssum[256];
+    
+    // Find max
+    float threadMax = -INFINITY;
+    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
+        threadMax = fmaxf(threadMax, rowData[i]);
+    }
+    smax[threadIdx.x] = threadMax;
+    __syncthreads();
+    
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (threadIdx.x < s) {
+            smax[threadIdx.x] = fmaxf(smax[threadIdx.x], smax[threadIdx.x + s]);
+        }
+        __syncthreads();
+    }
+    float maxVal = smax[0];
+    
+    // Compute exp and sum
+    float threadSum = 0.0f;
+    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
+        float val = expf(rowData[i] - maxVal);
+        rowData[i] = val;
+        threadSum += val;
+    }
+    ssum[threadIdx.x] = threadSum;
+    __syncthreads();
+    
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (threadIdx.x < s) {
+            ssum[threadIdx.x] += ssum[threadIdx.x + s];
+        }
+        __syncthreads();
+    }
+    float sum = ssum[0];
+    
+    // Normalize
+    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
+        rowData[i] /= sum;
+    }
+}
+
+int cuda_softmax_f32(float* x, int rows, int cols) {
+    int threads = 256;
+    softmax_kernel<<<rows, threads>>>(x, cols);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// ============================================================
+// Top-K Logits Selection (for sampling without full D2H)
+// ============================================================
+
+#define TOPK_THREADS 256
+#define TOPK_LOCAL 8
+#define TOPK_SEGMENT (TOPK_THREADS * TOPK_LOCAL) // 2048
+#define TOPK_MAX_K 64
+
+static __device__ __forceinline__ float apply_rep_penalty(float score, bool hit, float penalty) {
+    if (!hit) return score;
+    if (score > 0.0f) return score / penalty;
+    return score * penalty;
+}
+
+__global__ void topk_logits_kernel(
+    const float* logits, int vocab,
+    const int* rep_ids, int rep_count, float rep_penalty,
+    int k,
+    int* out_ids, float* out_scores
+) {
+    const int blockStart = blockIdx.x * TOPK_SEGMENT;
+    const int tid = threadIdx.x;
+    if (tid >= TOPK_THREADS) return;
+
+    float localScores[TOPK_LOCAL];
+    int localIds[TOPK_LOCAL];
+    #pragma unroll
+    for (int i = 0; i < TOPK_LOCAL; i++) {
+        localScores[i] = -INFINITY;
+        localIds[i] = -1;
+    }
+
+    // Each thread processes up to TOPK_LOCAL elements in a contiguous segment.
+    #pragma unroll
+    for (int j = 0; j < TOPK_LOCAL; j++) {
+        int idx = blockStart + tid + j * TOPK_THREADS;
+        if (idx >= vocab) break;
+        float score = logits[idx];
+        bool hit = false;
+        // rep_count is small (<=64). Linear scan is fine.
+        for (int r = 0; r < rep_count; r++) {
+            if (rep_ids[r] == idx) {
+                hit = true;
+                break;
+            }
+        }
+        score = apply_rep_penalty(score, hit, rep_penalty);
+
+        // Insert into local top list (descending)
+        int pos = TOPK_LOCAL;
+        for (int t = 0; t < TOPK_LOCAL; t++) {
+            if (score > localScores[t]) { pos = t; break; }
+        }
+        if (pos < TOPK_LOCAL) {
+            for (int t = TOPK_LOCAL - 1; t > pos; t--) {
+                localScores[t] = localScores[t-1];
+                localIds[t] = localIds[t-1];
+            }
+            localScores[pos] = score;
+            localIds[pos] = idx;
+        }
+    }
+
+    __shared__ float shScores[TOPK_SEGMENT];
+    __shared__ int shIds[TOPK_SEGMENT];
+
+    // Write local results to shared candidate pool.
+    #pragma unroll
+    for (int j = 0; j < TOPK_LOCAL; j++) {
+        int out = tid * TOPK_LOCAL + j;
+        shScores[out] = localScores[j];
+        shIds[out] = localIds[j];
+    }
+    __syncthreads();
+
+    if (tid == 0) {
+        // Block-level exact top-k from TOPK_SEGMENT candidates.
+        if (k > TOPK_MAX_K) k = TOPK_MAX_K;
+        float bestScores[TOPK_MAX_K];
+        int bestIds[TOPK_MAX_K];
+        for (int i = 0; i < k; i++) {
+            bestScores[i] = -INFINITY;
+            bestIds[i] = -1;
+        }
+        for (int i = 0; i < TOPK_SEGMENT; i++) {
+            float score = shScores[i];
+            int id = shIds[i];
+            if (id < 0) continue;
+            // Insert into best (descending)
+            if (score <= bestScores[k-1]) continue;
+            int pos = k;
+            for (int t = 0; t < k; t++) {
+                if (score > bestScores[t]) { pos = t; break; }
+            }
+            if (pos < k) {
+                for (int t = k - 1; t > pos; t--) {
+                    bestScores[t] = bestScores[t-1];
+                    bestIds[t] = bestIds[t-1];
+                }
+                bestScores[pos] = score;
+                bestIds[pos] = id;
+            }
+        }
+        int base = blockIdx.x * k;
+        for (int i = 0; i < k; i++) {
+            out_ids[base + i] = bestIds[i];
+            out_scores[base + i] = bestScores[i];
+        }
+    }
+}
+
+int cuda_topk_logits_f32(
+    const float* logits, int vocab,
+    const int* rep_ids, int rep_count, float rep_penalty,
+    int k,
+    int* out_ids, float* out_scores
+) {
+    if (k <= 0) return 0;
+    if (k > TOPK_MAX_K) k = TOPK_MAX_K;
+    int blocks = (vocab + TOPK_SEGMENT - 1) / TOPK_SEGMENT;
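+    // Each block produces an exact top-k for its TOPK_SEGMENT-wide slice of the vocabulary,
+    // so out_ids/out_scores receive blocks*k candidates; the merge across blocks happens
+    // outside this launcher.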
+    dim3 grid(blocks);
+    dim3 block(TOPK_THREADS);
+    topk_logits_kernel<<<grid, block>>>(logits, vocab, rep_ids, rep_count, rep_penalty, k, out_ids, out_scores);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// ============================================================
+// Attention Kernel
+// Computes: softmax((Q @ K.T) * scale + causal_mask) @ V, with scale typically 1/sqrt(headDim)
+// ============================================================
+__global__ void attention_kernel(
+    const float* Q, const float* K, const float* V, float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    float scale, int startPos
+) {
+    // Each block handles one (seq, head) pair
+    int seq = blockIdx.x;
+    int head = blockIdx.y;
+    
+    int kvHead = head / (numHeads / numKVHeads); // GQA support
+    
+    const float* q = Q + seq * numHeads * headDim + head * headDim;
+    float* o = out + seq * numHeads * headDim + head * headDim;
+    
+    // Shared memory for attention scores
+    extern __shared__ float shared[];
+    float* scores = shared; // [kvLen]
+    
+    // Compute Q @ K.T for this head
+    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
+        const float* k = K + kv * numKVHeads * headDim + kvHead * headDim;
+        
+        float dot = 0.0f;
+        for (int d = 0; d < headDim; d++) {
+            dot += q[d] * k[d];
+        }
+        
+        // Apply causal mask
+        int queryPos = startPos + seq;
+        int keyPos = kv;
+        if (keyPos > queryPos) {
+            dot = -INFINITY;
+        }
+        
+        scores[kv] = dot * scale;
+    }
+    __syncthreads();
+    
+    // Softmax over scores
+    // Find max
+    float maxVal = -INFINITY;
+    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
+        maxVal = fmaxf(maxVal, scores[kv]);
+    }
+    
+    // Reduce max across threads
+    __shared__ float smax[256];
+    smax[threadIdx.x] = maxVal;
+    __syncthreads();
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (threadIdx.x < s) {
+            smax[threadIdx.x] = fmaxf(smax[threadIdx.x], smax[threadIdx.x + s]);
+        }
+        __syncthreads();
+    }
+    maxVal = smax[0];
+    
+    // Exp and sum
+    float sum = 0.0f;
+    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
+        float val = expf(scores[kv] - maxVal);
+        scores[kv] = val;
+        sum += val;
+    }
+    
+    __shared__ float ssum[256];
+    ssum[threadIdx.x] = sum;
+    __syncthreads();
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (threadIdx.x < s) {
+            ssum[threadIdx.x] += ssum[threadIdx.x + s];
+        }
+        __syncthreads();
+    }
+    sum = ssum[0];
+    
+    // Normalize
+    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
+        scores[kv] /= sum;
+    }
+    __syncthreads();
+    
+    // Compute weighted sum of V
+    for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
+        float val = 0.0f;
+        for (int kv = 0; kv < kvLen; kv++) {
+            const float* v = V + kv * numKVHeads * headDim + kvHead * headDim;
+            val += scores[kv] * v[d];
+        }
+        o[d] = val;
+    }
+}
+
+__global__ void paged_attention_kernel(
+    const float* Q,
+    const float* const* KBlocks,
+    const float* const* VBlocks,
+    float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale, int startPos
+) {
+    int seq = blockIdx.x;
+    int head = blockIdx.y;
+
+    int kvHead = head / (numHeads / numKVHeads);
+
+    const float* q = Q + seq * numHeads * headDim + head * headDim;
+    float* o = out + seq * numHeads * headDim + head * headDim;
+
+    const int kvStride = numKVHeads * headDim;
+
+    int queryPos = startPos + seq;
+
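+    // Three passes over the KV blocks: running max, exp-sum, then weighted V accumulation;
+    // scores are recomputed in each pass instead of being staged in shared memory.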
+    // Pass 1: max
+    float localMax = -INFINITY;
+    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
+        int keyPos = kv;
+        if (keyPos > queryPos) {
+            continue;
+        }
+        const int blockIdxKV = kv / blockSize;
+        const int blockOff = kv % blockSize;
+        const float* kBlock = KBlocks[blockIdxKV];
+        const float* k = kBlock + blockOff * kvStride + kvHead * headDim;
+
+        float dot = 0.0f;
+        for (int d = 0; d < headDim; d++) {
+            dot += q[d] * k[d];
+        }
+        float score = dot * scale;
+        localMax = fmaxf(localMax, score);
+    }
+
+    __shared__ float smax[256];
+    smax[threadIdx.x] = localMax;
+    __syncthreads();
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (threadIdx.x < s) {
+            smax[threadIdx.x] = fmaxf(smax[threadIdx.x], smax[threadIdx.x + s]);
+        }
+        __syncthreads();
+    }
+    float maxVal = smax[0];
+
+    // Pass 2: sum exp
+    float localSum = 0.0f;
+    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
+        int keyPos = kv;
+        if (keyPos > queryPos) {
+            continue;
+        }
+        const int blockIdxKV = kv / blockSize;
+        const int blockOff = kv % blockSize;
+        const float* kBlock = KBlocks[blockIdxKV];
+        const float* k = kBlock + blockOff * kvStride + kvHead * headDim;
+
+        float dot = 0.0f;
+        for (int d = 0; d < headDim; d++) {
+            dot += q[d] * k[d];
+        }
+        float score = dot * scale;
+        localSum += expf(score - maxVal);
+    }
+
+    __shared__ float ssum[256];
+    ssum[threadIdx.x] = localSum;
+    __syncthreads();
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (threadIdx.x < s) {
+            ssum[threadIdx.x] += ssum[threadIdx.x + s];
+        }
+        __syncthreads();
+    }
+    float sumVal = ssum[0];
+    float invSum = (sumVal > 0.0f) ? (1.0f / sumVal) : 0.0f;
+
+    // Pass 3: output accumulation
+    for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
+        float outVal = 0.0f;
+        for (int kv = 0; kv < kvLen; kv++) {
+            int keyPos = kv;
+            if (keyPos > queryPos) {
+                break;
+            }
+            const int blockIdxKV = kv / blockSize;
+            const int blockOff = kv % blockSize;
+            const float* kBlock = KBlocks[blockIdxKV];
+            const float* k = kBlock + blockOff * kvStride + kvHead * headDim;
+
+            float dot = 0.0f;
+            for (int kd = 0; kd < headDim; kd++) {
+                dot += q[kd] * k[kd];
+            }
+            float score = dot * scale;
+            float w = expf(score - maxVal) * invSum;
+
+            const float* vBlock = VBlocks[blockIdxKV];
+            const float* v = vBlock + blockOff * kvStride + kvHead * headDim;
+            outVal += w * v[d];
+        }
+        o[d] = outVal;
+    }
+}
+
+__global__ void paged_attention_batch_kernel(
+    const float* Q,
+    const float* const* KBlocksFlat,
+    const float* const* VBlocksFlat,
+    const int* blockOffsets,
+    const int* kvLens,
+    const int* queryPos,
+    float* out,
+    int numTokens,
+    int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale
+) {
+    int tok = blockIdx.x;
+    int head = blockIdx.y;
+    if (tok >= numTokens) {
+        return;
+    }
+
+    int kvHead = head / (numHeads / numKVHeads);
+
+    const float* q = Q + tok * numHeads * headDim + head * headDim;
+    float* o = out + tok * numHeads * headDim + head * headDim;
+
+    int kvLen = kvLens[tok];
+    int qPos = queryPos[tok];
+    int base = blockOffsets[tok];
+
+    const int kvStride = numKVHeads * headDim;
+    const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
+
+    // Per-thread accumulator for one output dimension (only threads with
+    // threadIdx.x < headDim contribute to the output below).
+    float acc = 0.0f;
+
+    __shared__ float m;
+    __shared__ float l;
+    __shared__ float alpha;
+    __shared__ float beta;
+    __shared__ float dotShared;
+
+    if (threadIdx.x == 0) {
+        m = -INFINITY;
+        l = 0.0f;
+    }
+    __syncthreads();
+
+    for (int kv = 0; kv < effectiveLen; kv++) {
+        const int bidx = kv / blockSize;
+        const int boff = kv % blockSize;
+
+        const float* kBlock = KBlocksFlat[base + bidx];
+        const float* k = kBlock + boff * kvStride + kvHead * headDim;
+
+        float partial = 0.0f;
+        for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
+            partial = fmaf(q[d], k[d], partial);
+        }
+
+        // block reduction (sum)
+        for (int offset = 16; offset > 0; offset >>= 1) {
+            partial += __shfl_down_sync(0xffffffff, partial, offset);
+        }
+        __shared__ float warpSum[8];
+        int lane = threadIdx.x & 31;
+        int warp = threadIdx.x >> 5;
+        if (lane == 0) {
+            warpSum[warp] = partial;
+        }
+        __syncthreads();
+        if (warp == 0) {
+            float v = (lane < 8) ? warpSum[lane] : 0.0f;
+            for (int offset = 16; offset > 0; offset >>= 1) {
+                v += __shfl_down_sync(0xffffffff, v, offset);
+            }
+            if (lane == 0) {
+                dotShared = v;
+            }
+        }
+        __syncthreads();
+
+        float score = dotShared * scale;
+
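+        // Streaming (flash-style) softmax update: rescale the running accumulator by
+        // alpha = exp(m_old - m_new), add the new key with weight beta = exp(score - m_new),
+        // and keep l as the running denominator.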
+        if (threadIdx.x == 0) {
+            float newM = fmaxf(m, score);
+            float a = expf(m - newM);
+            float b = expf(score - newM);
+            m = newM;
+            l = l * a + b;
+            alpha = a;
+            beta = b;
+        }
+        __syncthreads();
+
+        if (threadIdx.x < headDim) {
+            const float* vBlock = VBlocksFlat[base + bidx];
+            const float* v = vBlock + boff * kvStride + kvHead * headDim;
+            acc = fmaf(beta, v[threadIdx.x], acc * alpha);
+        }
+        __syncthreads();
+    }
+
+    if (threadIdx.x < headDim) {
+        float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
+        o[threadIdx.x] = acc * invL;
+    }
+}
+
+int cuda_attention_f32(
+    const float* Q, const float* K, const float* V, float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    float scale, int startPos
+) {
+    dim3 blocks(seqLen, numHeads);
+    int threads = 256;
+    size_t sharedMem = kvLen * sizeof(float);
+    
+    attention_kernel<<<blocks, threads, sharedMem>>>(
+        Q, K, V, out, seqLen, kvLen, numHeads, numKVHeads, headDim, scale, startPos
+    );
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+int cuda_paged_attention_f32(
+    const float* Q,
+    const float* const* KBlocks,
+    const float* const* VBlocks,
+    float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale, int startPos
+) {
+    // Split-KV Flash Decoding for long contexts.
+    const int maxEffectiveLen = (kvLen < (startPos + seqLen)) ? kvLen : (startPos + seqLen);
+    const int numSplits = (maxEffectiveLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
+    const int qhCount = seqLen * numHeads;
+    const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
+
+    if (!useSplit) {
+        dim3 blocks(seqLen, numHeads);
+        int threads = 256;
+        paged_attention_kernel<<<blocks, threads>>>(
+            Q, KBlocks, VBlocks, out,
+            seqLen, kvLen, numHeads, numKVHeads, headDim,
+            blockSize, scale, startPos
+        );
+        CHECK_CUDA(cudaGetLastError());
+        return 0;
+    }
+
+    const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
+    const size_t totalFloats = splitCount * (size_t)(headDim + 2);
+    float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
+    if (buf == nullptr) {
+        return 1;
+    }
+    float* partialMax = buf;
+    float* partialSum = partialMax + splitCount;
+    float* partialOut = partialSum + splitCount;
+
+    dim3 blocks1(seqLen, numHeads, numSplits);
+    paged_attention_split_kv_kernel<float><<<blocks1, 32>>>(
+        Q, KBlocks, VBlocks,
+        partialMax, partialSum, partialOut,
+        seqLen, kvLen, numHeads, numKVHeads, headDim,
+        blockSize,
+        scale, startPos,
+        numSplits, kPagedAttentionSplitSize
+    );
+    CHECK_CUDA(cudaGetLastError());
+
+    dim3 blocks2(seqLen, numHeads);
+    paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
+        partialMax, partialSum, partialOut,
+        out,
+        seqLen, numHeads, headDim, numSplits
+    );
+    CHECK_CUDA(cudaGetLastError());
+
+    cuda_free(buf);
+    return 0;
+}
+
+int cuda_attention_f32_timed(
+    const float* Q, const float* K, const float* V, float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    float scale, int startPos, float* ms
+) {
+    cudaEvent_t evStart;
+    cudaEvent_t evStop;
+    CHECK_CUDA(cudaEventCreate(&evStart));
+    CHECK_CUDA(cudaEventCreate(&evStop));
+
+    dim3 blocks(seqLen, numHeads);
+    int threads = 256;
+    size_t sharedMem = kvLen * sizeof(float);
+
+    CHECK_CUDA(cudaEventRecord(evStart));
+    attention_kernel<<<blocks, threads, sharedMem>>>(
+        Q, K, V, out, seqLen, kvLen, numHeads, numKVHeads, headDim, scale, startPos
+    );
+    CHECK_CUDA(cudaEventRecord(evStop));
+    CHECK_CUDA(cudaEventSynchronize(evStop));
+
+    float elapsed = 0.0f;
+    CHECK_CUDA(cudaEventElapsedTime(&elapsed, evStart, evStop));
+    if (ms != nullptr) {
+        *ms = elapsed;
+    }
+
+    CHECK_CUDA(cudaEventDestroy(evStart));
+    CHECK_CUDA(cudaEventDestroy(evStop));
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
+
+// Debug helper
+int cuda_print_struct_sizes() {
+    printf("GPU Struct Sizes:\n");
+    printf("BlockQ2_K: %lu\n", sizeof(BlockQ2_K));
+    printf("BlockQ3_K: %lu\n", sizeof(BlockQ3_K));
+    printf("BlockQ4_K: %lu\n", sizeof(BlockQ4_K));
+    printf("BlockQ6_K: %lu\n", sizeof(BlockQ6_K));
+    printf("BlockQ8_K: %lu\n", sizeof(BlockQ8_K));
+    return 0;
+}

+ 221 - 0
pkg/backend/cuda/cuda_stub.go

@@ -0,0 +1,221 @@
+//go:build !cuda
+
+package cuda
+
+import (
+	"errors"
+	"unsafe"
+
+	"makarna/pkg/tensor"
+)
+
+var ErrCUDANotAvailable = errors.New("CUDA support not compiled in - build with -tags=cuda")
+
+// MemoryInfo returns (total, free) bytes for the current CUDA device.
+// In non-CUDA builds this always returns ErrCUDANotAvailable.
+func MemoryInfo() (total uint64, free uint64, err error) {
+	return 0, 0, ErrCUDANotAvailable
+}
+
+func MemoryInfoDevice(gpu int) (total uint64, free uint64, err error) {
+	return 0, 0, ErrCUDANotAvailable
+}
+
+func DeviceCount() (int, error) {
+	return 0, ErrCUDANotAvailable
+}
+
+// Tensor is a stub when CUDA is not available
+type Tensor struct {
+	shape tensor.Shape
+	dtype tensor.DType
+	gpu   int
+}
+
+func NewTensor(shape tensor.Shape, dtype tensor.DType, gpu int) (*Tensor, error) {
+	return nil, ErrCUDANotAvailable
+}
+
+func (t *Tensor) Shape() tensor.Shape       { return nil }
+func (t *Tensor) DType() tensor.DType       { return 0 }
+func (t *Tensor) Device() tensor.DeviceType { return tensor.CPU }
+func (t *Tensor) GPU() int                  { return -1 }
+func (t *Tensor) Placement() tensor.DevicePlacement {
+	return tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
+}
+func (t *Tensor) Data() interface{}                                 { return nil }
+func (t *Tensor) Free()                                             {}
+func (t *Tensor) Add(other tensor.Tensor) error                     { return ErrCUDANotAvailable }
+func (t *Tensor) Mul(other tensor.Tensor) error                     { return ErrCUDANotAvailable }
+func (t *Tensor) MatMul(other, out tensor.Tensor) error             { return ErrCUDANotAvailable }
+func (t *Tensor) Reshape(shape tensor.Shape) (tensor.Tensor, error) { return nil, ErrCUDANotAvailable }
+func (t *Tensor) View(shape tensor.Shape) (tensor.Tensor, error)    { return nil, ErrCUDANotAvailable }
+func (t *Tensor) ViewAt(shape tensor.Shape, offsetBytes uintptr) (*Tensor, error) {
+	return nil, ErrCUDANotAvailable
+}
+func (t *Tensor) ToDevice(device tensor.DeviceType) (tensor.Tensor, error) {
+	return nil, ErrCUDANotAvailable
+}
+func (t *Tensor) CopyFrom(data interface{}) error                    { return ErrCUDANotAvailable }
+func (t *Tensor) CopyToHost(dst []float32) error                     { return ErrCUDANotAvailable }
+func (t *Tensor) CopyPartialFrom(dstOffset int, src []float32) error { return ErrCUDANotAvailable }
+func (t *Tensor) CopyPartialFromDevice(dstOffset int, src *Tensor, srcOffset int, length int) error {
+	return ErrCUDANotAvailable
+}
+
+func MemcpyH2D(dst, src unsafe.Pointer, size uintptr, gpu int) error { return ErrCUDANotAvailable }
+
+func MemcpyD2H(dst, src unsafe.Pointer, size uintptr, gpu int) error { return ErrCUDANotAvailable }
+
+func MemcpyD2D(dst, src unsafe.Pointer, size uintptr, gpu int) error { return ErrCUDANotAvailable }
+
+func CastF32ToF16(srcF32, dstF16 unsafe.Pointer, n int, gpu int) error { return ErrCUDANotAvailable }
+
+func KDACausalShortConv1D(x, state, w unsafe.Pointer, tokens, projSize, kernel int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func L2NormHeads(q, k unsafe.Pointer, tokens, numHeads, headDim int, eps float32, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func KDAGate(g, aLog, dtBias, out unsafe.Pointer, tokens, numHeads, headDim int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func KDARecurrent(q, k, v, g, beta, state unsafe.Pointer, tokens, numHeads, headDim int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func RMSNormGated(out, g, weight unsafe.Pointer, n, headDim int, eps float32, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func Sigmoid(x unsafe.Pointer, n int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func SoftmaxRows(x unsafe.Pointer, rows, cols int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func TopKPerRow(scores unsafe.Pointer, indices unsafe.Pointer, values unsafe.Pointer, rows, cols, k int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func PagedAttention(Q, kBlocksDev, vBlocksDev, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize int, scale float32, startPos int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func PagedAttentionBatch(Q, kBlocksFlatDev, vBlocksFlatDev, blockOffsetsDev, kvLensDev, queryPosDev, out unsafe.Pointer, numTokens, numHeads, numKVHeads, headDim, blockSize int, scale float32, maxKvLen int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func PagedAttentionF32F16KV(Q, kBlocksDev, vBlocksDev, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize int, scale float32, startPos int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func PagedAttentionBatchF32F16KV(Q, kBlocksFlatDev, vBlocksFlatDev, blockOffsetsDev, kvLensDev, queryPosDev, out unsafe.Pointer, numTokens, numHeads, numKVHeads, headDim, blockSize int, scale float32, maxKvLen int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func PagedAttentionRoPEF32F16KV(Q, kBlocksDev, vBlocksDev, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize int, scale float32, startPos int, theta float32, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func PagedAttentionBatchRoPEF32F16KV(Q, kBlocksFlatDev, vBlocksFlatDev, blockOffsetsDev, kvLensDev, queryPosDev, out unsafe.Pointer, numTokens, numHeads, numKVHeads, headDim, blockSize int, scale float32, maxKvLen int, theta float32, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func AllocAndCopyInt32(data []int32, gpu int) (unsafe.Pointer, error) {
+	return nil, ErrCUDANotAvailable
+}
+
+func TopKLogitsF32(logits unsafe.Pointer, vocab int, repIDs []int32, repPenalty float32, k int, gpu int) ([]int32, []float32, int, error) {
+	return nil, nil, 0, ErrCUDANotAvailable
+}
+
+func DequantQ8K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func DequantQ4K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func DequantQ5K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func DequantQ6K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func DequantQ3K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func DequantQ2K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func MatMulQ2K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func MatMulQ4K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func MatMulQ5K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func MatMulQ3K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func MatMulQ6K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func MatMulQ8K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func MatMulF16Q8K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func MatMulF16Q4K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func MatMulF16Q5K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func MatMulF16Q2K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func MatMulF16Q3K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func MatMulF16Q6K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
+	return ErrCUDANotAvailable
+}
+
+func FreeDevicePtr(ptr unsafe.Pointer) {}
+
+func Free(ptr unsafe.Pointer) {}
+
+func AllocAndCopyPtrTable(ptrs []uintptr, gpu int) (unsafe.Pointer, error) {
+	return nil, ErrCUDANotAvailable
+}
+
+// Available returns whether CUDA is available
+func Available() bool {
+	return false
+}
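+
+// Usage sketch (an assumption about call sites, not an API defined here): callers can gate
+// GPU code paths on Available() so that non-CUDA builds fall back to the CPU backend:
+//
+//	if cuda.Available() {
+//		// dispatch work to the CUDA backend
+//	} else {
+//		// fall back to the CPU implementation
+//	}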

+ 527 - 0
pkg/backend/cuda/dequant_test.go

@@ -0,0 +1,527 @@
+//go:build cuda
+
+package cuda
+
+import (
+	"testing"
+	"unsafe"
+
+	"makarna/pkg/quant"
+	"makarna/pkg/tensor"
+)
+
+func TestDequantQ8K_CUDA(t *testing.T) {
+	// Create a simple Q8_K block
+	// Block layout: 4 bytes D (float32) + 256 bytes qs (int8) + 32 bytes bsums
+	blockSize := 292
+	hostBlock := make([]byte, blockSize)
+	
+	// Set D = 0.5 (as float32 bytes)
+	d := float32(0.5)
+	dBytes := (*[4]byte)(unsafe.Pointer(&d))[:]
+	copy(hostBlock[0:4], dBytes)
+	
+	// Set qs to a ramp of int8 values: -128, -127, ..., 127
+	for i := 0; i < 256; i++ {
+		hostBlock[4+i] = byte(int8(i - 128)) // Range -128 to 127
+	}
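+	// With d = 0.5 the expected dequantized value at index i is 0.5 * float32(int8(i-128)).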
+	
+	// Upload block to GPU
+	gpu := 0
+	devBlocks, err := UploadQ8K(hostBlock, 1, gpu)
+	if err != nil {
+		t.Fatalf("UploadQ8K failed: %v", err)
+	}
+	defer FreeDevicePtr(devBlocks)
+	
+	// Allocate output on GPU
+	outTensor, err := NewTensor(tensor.Shape{256}, tensor.Float32, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor failed: %v", err)
+	}
+	
+	// Dequantize
+	err = DequantQ8K(devBlocks, outTensor.ptr, 1, gpu)
+	if err != nil {
+		t.Fatalf("DequantQ8K failed: %v", err)
+	}
+	
+	// Copy back and verify
+	hostOut := make([]float32, 256)
+	if err := outTensor.CopyToHost(hostOut); err != nil {
+		t.Fatalf("CopyToHost failed: %v", err)
+	}
+	
+	// Check first few values
+	for i := 0; i < 10; i++ {
+		expected := float32(0.5) * float32(int8(i-128))
+		if diff := hostOut[i] - expected; diff < -0.001 || diff > 0.001 {
+			t.Errorf("out[%d] = %f, expected %f", i, hostOut[i], expected)
+		}
+	}
+	
+	t.Logf("Q8_K CUDA dequant test passed, sample outputs: %.4f, %.4f, %.4f", 
+		hostOut[0], hostOut[128], hostOut[255])
+}
+
+func TestMatMulQ8K_CUDA(t *testing.T) {
+	// Small fused matmul test: A is [2, 256] filled with ones and B is a Q8_K weight with
+	// N=2 rows. Q8_K requires K to be a multiple of 256, hence M=2, K=256, N=2.
+	M, K, N := 2, 256, 2
+	gpu := 0
+	
+	// Create input A on GPU [2, 256]
+	aTensor, err := NewTensor(tensor.Shape{M, K}, tensor.Float32, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor A failed: %v", err)
+	}
+	
+	// Fill A with 1.0
+	hostA := make([]float32, M*K)
+	for i := range hostA {
+		hostA[i] = 1.0
+	}
+	if err := aTensor.CopyFrom(hostA); err != nil {
+		t.Fatalf("CopyFrom A failed: %v", err)
+	}
+	
+	// Create Q8_K weight B: N rows, each with K/256 = 1 block
+	// Each block: d=1.0, qs=all 1s -> dequant = 1.0 for all
+	blockSize := 292
+	numBlocks := N * (K / 256) // 2 * 1 = 2 blocks
+	hostB := make([]byte, numBlocks*blockSize)
+	
+	d := float32(1.0)
+	dBytes := (*[4]byte)(unsafe.Pointer(&d))[:]
+	
+	for blk := 0; blk < numBlocks; blk++ {
+		offset := blk * blockSize
+		copy(hostB[offset:offset+4], dBytes)
+		// qs = all 1s
+		for i := 0; i < 256; i++ {
+			hostB[offset+4+i] = 1
+		}
+	}
+	
+	devB, err := UploadQ8K(hostB, numBlocks, gpu)
+	if err != nil {
+		t.Fatalf("UploadQ8K B failed: %v", err)
+	}
+	defer FreeDevicePtr(devB)
+	
+	// Create output C on GPU [2, 2]
+	cTensor, err := NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor C failed: %v", err)
+	}
+	
+	// Run fused matmul
+	err = MatMulQ8K(aTensor.ptr, devB, cTensor.ptr, M, K, N, gpu)
+	if err != nil {
+		t.Fatalf("MatMulQ8K failed: %v", err)
+	}
+	
+	// Copy back and verify
+	// C = A @ dequant(B) = [1,1,...] @ [1,1,...].T = 256.0 per element
+	hostC := make([]float32, M*N)
+	if err := cTensor.CopyToHost(hostC); err != nil {
+		t.Fatalf("CopyToHost C failed: %v", err)
+	}
+	
+	expected := float32(256.0) // Sum of 256 1s
+	for i, v := range hostC {
+		if diff := v - expected; diff < -1.0 || diff > 1.0 {
+			t.Errorf("C[%d] = %f, expected %f", i, v, expected)
+		}
+	}
+	
+	t.Logf("MatMulQ8K CUDA test passed, outputs: %v", hostC)
+}
+
+func TestMatMulF16Q8K_CUDA(t *testing.T) {
+	// Same as TestMatMulQ8K_CUDA but uses FP16 input kernel.
+	M, K, N := 2, 256, 2
+	gpu := 0
+
+	// Create input A on GPU [2, 256] as FP32 then cast to FP16 on GPU.
+	aF32, err := NewTensor(tensor.Shape{M, K}, tensor.Float32, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor A(F32) failed: %v", err)
+	}
+
+	aF16, err := NewTensor(tensor.Shape{M, K}, tensor.Float16, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor A(F16) failed: %v", err)
+	}
+
+	// Fill A with 1.0
+	hostA := make([]float32, M*K)
+	for i := range hostA {
+		hostA[i] = 1.0
+	}
+	if err := aF32.CopyFrom(hostA); err != nil {
+		t.Fatalf("CopyFrom A(F32) failed: %v", err)
+	}
+	if err := CastF32ToF16(aF32.ptr, aF16.ptr, M*K, gpu); err != nil {
+		t.Fatalf("CastF32ToF16 failed: %v", err)
+	}
+
+	// Create Q8_K weight B: N rows, each with K/256 = 1 block
+	blockSize := 292
+	numBlocks := N * (K / 256)
+	hostB := make([]byte, numBlocks*blockSize)
+
+	d := float32(1.0)
+	dBytes := (*[4]byte)(unsafe.Pointer(&d))[:]
+	for blk := 0; blk < numBlocks; blk++ {
+		offset := blk * blockSize
+		copy(hostB[offset:offset+4], dBytes)
+		for i := 0; i < 256; i++ {
+			hostB[offset+4+i] = 1
+		}
+	}
+
+	devB, err := UploadQ8K(hostB, numBlocks, gpu)
+	if err != nil {
+		t.Fatalf("UploadQ8K B failed: %v", err)
+	}
+	defer FreeDevicePtr(devB)
+
+	// Create output C on GPU [2, 2]
+	cTensor, err := NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor C failed: %v", err)
+	}
+
+	// Run fused matmul (FP16 input)
+	err = MatMulF16Q8K(aF16.ptr, devB, cTensor.ptr, M, K, N, gpu)
+	if err != nil {
+		t.Fatalf("MatMulF16Q8K failed: %v", err)
+	}
+
+	// Copy back and verify
+	hostC := make([]float32, M*N)
+	if err := cTensor.CopyToHost(hostC); err != nil {
+		t.Fatalf("CopyToHost C failed: %v", err)
+	}
+
+	expected := float32(256.0)
+	for i, v := range hostC {
+		if diff := v - expected; diff < -1.0 || diff > 1.0 {
+			t.Errorf("C[%d] = %f, expected %f", i, v, expected)
+		}
+	}
+
+	t.Logf("MatMulF16Q8K CUDA test passed, outputs: %v", hostC)
+}
+
+func TestMatMulF16Q4K_CUDA(t *testing.T) {
+	M, K, N := 2, 256, 2
+	gpu := 0
+
+	aF32, err := NewTensor(tensor.Shape{M, K}, tensor.Float32, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor A(F32) failed: %v", err)
+	}
+	aF16, err := NewTensor(tensor.Shape{M, K}, tensor.Float16, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor A(F16) failed: %v", err)
+	}
+
+	hostA := make([]float32, M*K)
+	for i := range hostA {
+		hostA[i] = 1.0
+	}
+	if err := aF32.CopyFrom(hostA); err != nil {
+		t.Fatalf("CopyFrom A(F32) failed: %v", err)
+	}
+	if err := CastF32ToF16(aF32.ptr, aF16.ptr, M*K, gpu); err != nil {
+		t.Fatalf("CastF32ToF16 failed: %v", err)
+	}
+
+	row := make([]float32, K)
+	for i := range row {
+		row[i] = 1.0
+	}
+	hostB := make([]byte, 0, N*144)
+	for i := 0; i < N; i++ {
+		hostB = append(hostB, quant.QuantizeQ4K(row)...)
+	}
+
+	devB, err := UploadQ4K(hostB, N*(K/256), gpu)
+	if err != nil {
+		t.Fatalf("UploadQ4K B failed: %v", err)
+	}
+	defer FreeDevicePtr(devB)
+
+	cTensor, err := NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor C failed: %v", err)
+	}
+
+	err = MatMulF16Q4K(aF16.ptr, devB, cTensor.ptr, M, K, N, gpu)
+	if err != nil {
+		t.Fatalf("MatMulF16Q4K failed: %v", err)
+	}
+
+	hostC := make([]float32, M*N)
+	if err := cTensor.CopyToHost(hostC); err != nil {
+		t.Fatalf("CopyToHost C failed: %v", err)
+	}
+
+	// Quantization may introduce small error; allow a bit more tolerance.
+	expected := float32(256.0)
+	for i, v := range hostC {
+		if diff := v - expected; diff < -4.0 || diff > 4.0 {
+			t.Errorf("C[%d] = %f, expected ~%f", i, v, expected)
+		}
+	}
+}
+
+func TestMatMulF16Q5K_CUDA(t *testing.T) {
+	M, K, N := 2, 256, 2
+	gpu := 0
+
+	aF32, err := NewTensor(tensor.Shape{M, K}, tensor.Float32, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor A(F32) failed: %v", err)
+	}
+	aF16, err := NewTensor(tensor.Shape{M, K}, tensor.Float16, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor A(F16) failed: %v", err)
+	}
+
+	hostA := make([]float32, M*K)
+	for i := range hostA {
+		hostA[i] = 1.0
+	}
+	if err := aF32.CopyFrom(hostA); err != nil {
+		t.Fatalf("CopyFrom A(F32) failed: %v", err)
+	}
+	if err := CastF32ToF16(aF32.ptr, aF16.ptr, M*K, gpu); err != nil {
+		t.Fatalf("CastF32ToF16 failed: %v", err)
+	}
+
+	row := make([]float32, K)
+	for i := range row {
+		row[i] = 1.0
+	}
+	hostB := make([]byte, 0, N*176)
+	for i := 0; i < N; i++ {
+		hostB = append(hostB, quant.QuantizeQ5K(row)...)
+	}
+
+	devB, err := UploadQ5K(hostB, N*(K/256), gpu)
+	if err != nil {
+		t.Fatalf("UploadQ5K B failed: %v", err)
+	}
+	defer FreeDevicePtr(devB)
+
+	cTensor, err := NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor C failed: %v", err)
+	}
+
+	err = MatMulF16Q5K(aF16.ptr, devB, cTensor.ptr, M, K, N, gpu)
+	if err != nil {
+		t.Fatalf("MatMulF16Q5K failed: %v", err)
+	}
+
+	hostC := make([]float32, M*N)
+	if err := cTensor.CopyToHost(hostC); err != nil {
+		t.Fatalf("CopyToHost C failed: %v", err)
+	}
+
+	expected := float32(256.0)
+	for i, v := range hostC {
+		if diff := v - expected; diff < -4.0 || diff > 4.0 {
+			t.Errorf("C[%d] = %f, expected ~%f", i, v, expected)
+		}
+	}
+}
+
+func TestMatMulF16Q2K_CUDA(t *testing.T) {
+	M, K, N := 2, 256, 2
+	gpu := 0
+
+	aF32, err := NewTensor(tensor.Shape{M, K}, tensor.Float32, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor A(F32) failed: %v", err)
+	}
+	aF16, err := NewTensor(tensor.Shape{M, K}, tensor.Float16, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor A(F16) failed: %v", err)
+	}
+
+	hostA := make([]float32, M*K)
+	for i := range hostA {
+		hostA[i] = 1.0
+	}
+	if err := aF32.CopyFrom(hostA); err != nil {
+		t.Fatalf("CopyFrom A(F32) failed: %v", err)
+	}
+	if err := CastF32ToF16(aF32.ptr, aF16.ptr, M*K, gpu); err != nil {
+		t.Fatalf("CastF32ToF16 failed: %v", err)
+	}
+
+	row := make([]float32, K)
+	for i := range row {
+		row[i] = 1.0
+	}
+	hostB := make([]byte, 0, N*84)
+	for i := 0; i < N; i++ {
+		hostB = append(hostB, quant.QuantizeQ2K(row)...)
+	}
+
+	devB, err := UploadQ2K(hostB, N*(K/256), gpu)
+	if err != nil {
+		t.Fatalf("UploadQ2K B failed: %v", err)
+	}
+	defer FreeDevicePtr(devB)
+
+	cTensor, err := NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor C failed: %v", err)
+	}
+
+	err = MatMulF16Q2K(aF16.ptr, devB, cTensor.ptr, M, K, N, gpu)
+	if err != nil {
+		t.Fatalf("MatMulF16Q2K failed: %v", err)
+	}
+
+	hostC := make([]float32, M*N)
+	if err := cTensor.CopyToHost(hostC); err != nil {
+		t.Fatalf("CopyToHost C failed: %v", err)
+	}
+
+	expected := float32(256.0)
+	for i, v := range hostC {
+		if diff := v - expected; diff < -12.0 || diff > 12.0 {
+			t.Errorf("C[%d] = %f, expected ~%f", i, v, expected)
+		}
+	}
+}
+
+func TestMatMulF16Q3K_CUDA(t *testing.T) {
+	M, K, N := 2, 256, 2
+	gpu := 0
+
+	aF32, err := NewTensor(tensor.Shape{M, K}, tensor.Float32, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor A(F32) failed: %v", err)
+	}
+	aF16, err := NewTensor(tensor.Shape{M, K}, tensor.Float16, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor A(F16) failed: %v", err)
+	}
+
+	hostA := make([]float32, M*K)
+	for i := range hostA {
+		hostA[i] = 1.0
+	}
+	if err := aF32.CopyFrom(hostA); err != nil {
+		t.Fatalf("CopyFrom A(F32) failed: %v", err)
+	}
+	if err := CastF32ToF16(aF32.ptr, aF16.ptr, M*K, gpu); err != nil {
+		t.Fatalf("CastF32ToF16 failed: %v", err)
+	}
+
+	row := make([]float32, K)
+	for i := range row {
+		row[i] = 1.0
+	}
+	hostB := make([]byte, 0, N*110)
+	for i := 0; i < N; i++ {
+		hostB = append(hostB, quant.QuantizeQ3K(row)...)
+	}
+
+	devB, err := UploadQ3K(hostB, N*(K/256), gpu)
+	if err != nil {
+		t.Fatalf("UploadQ3K B failed: %v", err)
+	}
+	defer FreeDevicePtr(devB)
+
+	cTensor, err := NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor C failed: %v", err)
+	}
+
+	err = MatMulF16Q3K(aF16.ptr, devB, cTensor.ptr, M, K, N, gpu)
+	if err != nil {
+		t.Fatalf("MatMulF16Q3K failed: %v", err)
+	}
+
+	hostC := make([]float32, M*N)
+	if err := cTensor.CopyToHost(hostC); err != nil {
+		t.Fatalf("CopyToHost C failed: %v", err)
+	}
+
+	expected := float32(256.0)
+	for i, v := range hostC {
+		if diff := v - expected; diff < -12.0 || diff > 12.0 {
+			t.Errorf("C[%d] = %f, expected ~%f", i, v, expected)
+		}
+	}
+}
+
+func TestMatMulF16Q6K_CUDA(t *testing.T) {
+	M, K, N := 2, 256, 2
+	gpu := 0
+
+	aF32, err := NewTensor(tensor.Shape{M, K}, tensor.Float32, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor A(F32) failed: %v", err)
+	}
+	aF16, err := NewTensor(tensor.Shape{M, K}, tensor.Float16, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor A(F16) failed: %v", err)
+	}
+
+	hostA := make([]float32, M*K)
+	for i := range hostA {
+		hostA[i] = 1.0
+	}
+	if err := aF32.CopyFrom(hostA); err != nil {
+		t.Fatalf("CopyFrom A(F32) failed: %v", err)
+	}
+	if err := CastF32ToF16(aF32.ptr, aF16.ptr, M*K, gpu); err != nil {
+		t.Fatalf("CastF32ToF16 failed: %v", err)
+	}
+
+	row := make([]float32, K)
+	for i := range row {
+		row[i] = 1.0
+	}
+	hostB := make([]byte, 0, N*210)
+	for i := 0; i < N; i++ {
+		hostB = append(hostB, quant.QuantizeQ6K(row)...)
+	}
+
+	devB, err := UploadQ6K(hostB, N*(K/256), gpu)
+	if err != nil {
+		t.Fatalf("UploadQ6K B failed: %v", err)
+	}
+	defer FreeDevicePtr(devB)
+
+	cTensor, err := NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
+	if err != nil {
+		t.Fatalf("NewTensor C failed: %v", err)
+	}
+
+	err = MatMulF16Q6K(aF16.ptr, devB, cTensor.ptr, M, K, N, gpu)
+	if err != nil {
+		t.Fatalf("MatMulF16Q6K failed: %v", err)
+	}
+
+	hostC := make([]float32, M*N)
+	if err := cTensor.CopyToHost(hostC); err != nil {
+		t.Fatalf("CopyToHost C failed: %v", err)
+	}
+
+	expected := float32(256.0)
+	for i, v := range hostC {
+		if diff := v - expected; diff < -8.0 || diff > 8.0 {
+			t.Errorf("C[%d] = %f, expected ~%f", i, v, expected)
+		}
+	}
+}

+ 8 - 0
pkg/backend/cuda/kernels.cu

@@ -0,0 +1,8 @@
+#include "cuda_memory.cu"
+#include "cuda_elementwise.cu"
+#include "cuda_dequant_q8k.cu"
+#include "cuda_dequant_q4k.cu"
+#include "cuda_dequant_q5k.cu"
+#include "cuda_dequant_other.cu"
+#include "cuda_matmul.cu"
+#include "cuda_nn.cu"

+ 344 - 0
pkg/backend/cuda/kernels.h

@@ -0,0 +1,344 @@
+#ifndef MAKARNA_CUDA_H
+#define MAKARNA_CUDA_H
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Memory Management
+int cuda_set_device(int id);
+void* cuda_malloc(size_t size);
+void cuda_free(void* ptr);
+int cuda_synchronize();
+int cuda_memcpy_h2d(void* dst, void* src, size_t size);
+int cuda_memcpy_d2h(void* dst, void* src, size_t size);
+int cuda_memcpy_d2d(void* dst, void* src, size_t size);
+int cuda_mem_info(size_t* free_bytes, size_t* total_bytes);
+int cuda_device_count(int* count);
+
+// Math Operations (Float32)
+// Launches kernels on the default stream
+int cuda_add_f32(float* a, float* b, size_t n);
+int cuda_mul_f32(float* a, float* b, size_t n);
+int cuda_matmul_f32(float* A, float* B, float* C, int M, int K, int N);
+// MatMul where B is row-major [N, K] (no host transpose needed).
+int cuda_matmul_f32_nt(float* A, float* B, float* C, int M, int K, int N);
+// MatMul where A and B are float16 (IEEE half stored as uint16).
+// B is row-major [N, K] and interpreted as column-major [K, N].
+int cuda_matmul_f16_nt(const unsigned short* A, const unsigned short* B, float* C, int M, int K, int N);
+
+// ============================================================
+// Neural Network Operations
+// ============================================================
+
+// RMSNorm: x = x * rsqrt(mean(x^2) + eps) * weight
+// x: [seqLen, dim], w: [dim] -> modifies x in-place
+int cuda_rmsnorm_f32(float* x, const float* w, int seqLen, int dim, float eps);
+
+// RoPE: Apply rotary positional embeddings in-place
+// x: [seqLen, numHeads * headDim]
+// positions: [seqLen] - position indices
+int cuda_rope_f32(float* x, const int* positions, int seqLen, int numHeads, int headDim, float theta);
+int cuda_rope_f32_single(float* x, int pos, int numHeads, int headDim, float theta);
+
+// Softmax: Apply softmax along last dimension
+// x: [rows, cols] -> in-place
+int cuda_softmax_f32(float* x, int rows, int cols);
+
+// Top-K selection on logits with optional repetition penalty.
+// logits: [vocab]
+// rep_ids: [rep_count] token ids to penalize
+// out_ids/out_scores: [numBlocks * k]
+// Returns 0 on success.
+int cuda_topk_logits_f32(
+    const float* logits, int vocab,
+    const int* rep_ids, int rep_count, float rep_penalty,
+    int k,
+    int* out_ids, float* out_scores);
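+// Example (sketch): with a 32000-entry vocabulary (an assumed size) and the 2048-wide
+// segments used by the current kernel, the launcher runs ceil(32000/2048) = 16 blocks,
+// so out_ids/out_scores must each provide room for 16*k entries.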
+
+// Causal Attention: Full attention computation
+// Q: [seqLen, numHeads * headDim]
+// K: [kvLen, numKVHeads * headDim]  
+// V: [kvLen, numKVHeads * headDim]
+// out: [seqLen, numHeads * headDim]
+// scale: typically 1/sqrt(headDim)
+// startPos: for causal mask offset (KV cache)
+int cuda_attention_f32(
+    const float* Q, const float* K, const float* V, float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    float scale, int startPos);
+
+int cuda_paged_attention_f32(
+    const float* Q,
+    const float* const* KBlocks,
+    const float* const* VBlocks,
+    float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale, int startPos);
+
+int cuda_paged_attention_batch_f32(
+    const float* Q,
+    const float* const* KBlocksFlat,
+    const float* const* VBlocksFlat,
+    const int* blockOffsets,
+    const int* kvLens,
+    const int* queryPos,
+    float* out,
+    int numTokens,
+    int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale,
+    int maxKvLen);
+
+// Paged attention where KV blocks are float16 (IEEE half stored as uint16).
+// Q and out are float32. Accumulation is float32.
+int cuda_paged_attention_f32_f16kv(
+    const float* Q,
+    const unsigned short* const* KBlocks,
+    const unsigned short* const* VBlocks,
+    float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale, int startPos);
+
+int cuda_paged_attention_batch_f32_f16kv(
+    const float* Q,
+    const unsigned short* const* KBlocksFlat,
+    const unsigned short* const* VBlocksFlat,
+    const int* blockOffsets,
+    const int* kvLens,
+    const int* queryPos,
+    float* out,
+    int numTokens,
+    int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale,
+    int maxKvLen);
+
+// Fused RoPE + paged attention where KV blocks are float16 (IEEE half stored as uint16).
+// Expects un-rotated Q and un-rotated K blocks; RoPE is applied on-the-fly in the attention kernel.
+int cuda_paged_attention_rope_f32_f16kv(
+    const float* Q,
+    const unsigned short* const* KBlocks,
+    const unsigned short* const* VBlocks,
+    float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale, int startPos,
+    float theta);
+
+int cuda_paged_attention_rope_batch_f32_f16kv(
+    const float* Q,
+    const unsigned short* const* KBlocksFlat,
+    const unsigned short* const* VBlocksFlat,
+    const int* blockOffsets,
+    const int* kvLens,
+    const int* queryPos,
+    float* out,
+    int numTokens,
+    int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale,
+    int maxKvLen,
+    float theta);
+
+// Cast float32 -> float16 (stored as uint16) on GPU.
+int cuda_cast_f32_to_f16(const float* src, unsigned short* dst, int n);
+
+int cuda_attention_f32_timed(
+    const float* Q, const float* K, const float* V, float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    float scale, int startPos, float* ms);
+
+// SiLU activation: x = x * sigmoid(x), in-place
+int cuda_silu_f32(float* x, size_t n);
+
+// Element-wise multiply: a = a * b, in-place
+int cuda_mul_inplace_f32(float* a, const float* b, size_t n);
+
+// Copy: dst = src
+int cuda_copy_f32(float* dst, const float* src, size_t n);
+
+int cuda_kda_causal_short_conv1d_f32(
+    float* x,
+    float* state,
+    const float* w,
+    int tokens,
+    int projSize,
+    int kernel);
+
+int cuda_l2norm_heads_f32(
+    float* q,
+    float* k,
+    int tokens,
+    int numHeads,
+    int headDim,
+    float eps);
+
+int cuda_kda_gate_f32(
+    const float* g,
+    const float* aLog,
+    const float* dtBias,
+    float* out,
+    int tokens,
+    int numHeads,
+    int headDim);
+
+int cuda_kda_recurrent_f32(
+    const float* q,
+    const float* k,
+    float* v,
+    const float* g,
+    // beta is a device pointer: [tokens, numHeads] (row-major).
+    const float* beta,
+    float* state,
+    int tokens,
+    int numHeads,
+    int headDim);
+
+int cuda_rmsnorm_gated_f32(
+    float* out,
+    const float* g,
+    const float* weight,
+    int n,
+    int headDim,
+    float eps);
+
+int cuda_sigmoid_f32(float* x, int n);
+
+int cuda_softmax_rows_f32(float* x, int rows, int cols);
+
+int cuda_topk_per_row_f32(
+    const float* scores,
+    int* indices,
+    float* values,
+    int rows,
+    int cols,
+    int k);
+
+// ============================================================
+// Dequantization Kernels
+// These convert quantized blocks to float32 on GPU
+// ============================================================
+
+// Block sizes for K-quantization
+#define QK_K 256
+
+// BlockQ8_K: 292 bytes per block (4 + 256 + 32)
+// - D (4 bytes): float32 scale
+// - QS (256 bytes): 256 int8 quants
+// - BSums (32 bytes): unused in dequant
+typedef struct {
+    float d;
+    signed char qs[256];
+    short bsums[16];
+} BlockQ8_K;
+
+// BlockQ4_K: 144 bytes per block
+// - D (2 bytes): float16 super-scale
+// - DMin (2 bytes): float16 super-min
+// - Scales (12 bytes): packed 6-bit scales/mins
+// - QS (128 bytes): 256 4-bit quants
+typedef struct {
+    unsigned short d;
+    unsigned short dmin;
+    unsigned char scales[12];
+    unsigned char qs[128];
+} BlockQ4_K;
+
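+// BlockQ5_K: 176 bytes per block
+// - D (2 bytes): float16 super-scale
+// - DMin (2 bytes): float16 super-min
+// - Scales (12 bytes): packed 6-bit scales/mins
+// - QH (32 bytes): high (5th) bits
+// - QS (128 bytes): low 4 bits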
+typedef struct {
+    unsigned short d;
+    unsigned short dmin;
+    unsigned char scales[12];
+    unsigned char qh[32];
+    unsigned char qs[128];
+} BlockQ5_K;
+
+// BlockQ6_K: 210 bytes per block
+// - QL (128 bytes): lower 4 bits
+// - QH (64 bytes): upper 2 bits
+// - Scales (16 bytes): 8-bit scales
+// - D (2 bytes): float16 super-scale
+typedef struct {
+    unsigned char ql[128];
+    unsigned char qh[64];
+    signed char scales[16];
+    unsigned short d;
+} BlockQ6_K;
+
+// BlockQ3_K: 110 bytes per block
+// - HMask (32 bytes): high bits
+// - QS (64 bytes): low 2 bits
+// - Scales (12 bytes): packed 6-bit scales
+// - D (2 bytes): float16 super-scale
+typedef struct {
+    unsigned char hmask[32];
+    unsigned char qs[64];
+    unsigned char scales[12];
+    unsigned short d;
+} BlockQ3_K;
+
+// BlockQ2_K: 84 bytes per block
+// - Scales (16 bytes): packed 4-bit scales/mins
+// - QS (64 bytes): 256 2-bit quants
+// - D (2 bytes): float16 super-scale
+// - DMin (2 bytes): float16 super-min
+typedef struct {
+    unsigned char scales[16];
+    unsigned char qs[64];
+    unsigned short d;
+    unsigned short dmin;
+} BlockQ2_K;
+
+// Dequantize a row of Q8_K blocks: numBlocks * 256 values -> out
+int cuda_dequant_q8k(const void* blocks, float* out, int numBlocks);
+
+// Dequantize a row of Q4_K blocks
+int cuda_dequant_q4k(const void* blocks, float* out, int numBlocks);
+
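+// Dequantize a row of Q5_K blocks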
+int cuda_dequant_q5k(const void* blocks, float* out, int numBlocks);
+
+// Dequantize a row of Q6_K blocks
+int cuda_dequant_q6k(const void* blocks, float* out, int numBlocks);
+
+// Dequantize a row of Q3_K blocks
+int cuda_dequant_q3k(const void* blocks, float* out, int numBlocks);
+
+// Dequantize a row of Q2_K blocks
+int cuda_dequant_q2k(const void* blocks, float* out, int numBlocks);
+
+// Fused Dequant + MatMul (for maximum performance)
+// A: [M, K] float32 input
+// B: quantized weight blocks [N rows, K/256 blocks per row]
+// C: [M, N] float32 output
+// This dequantizes B on-the-fly and computes C = A @ B.T
+int cuda_matmul_f32_q8k(float* A, const void* B, float* C, int M, int K, int N);
+int cuda_matmul_f32_q5k(float* A, const void* B, float* C, int M, int K, int N);
+int cuda_matmul_f32_q6k(float* A, const void* B, float* C, int M, int K, int N);
+int cuda_matmul_f32_q4k(float* A, const void* B, float* C, int M, int K, int N);
+int cuda_matmul_f32_q3k(float* A, const void* B, float* C, int M, int K, int N);
+int cuda_matmul_f32_q2k(float* A, const void* B, float* C, int M, int K, int N);
+
+int cuda_matmul_f32_q8k_timed(float* A, const void* B, float* C, int M, int K, int N, float* ms);
+int cuda_matmul_f32_q4k_timed(float* A, const void* B, float* C, int M, int K, int N, float* ms);
+
+// FP16 Input Variants - 2x memory bandwidth for activations
+// A: [M, K] float16 input, B: quantized, C: [M, N] float32 output
+int cuda_matmul_f16_q8k(const void* A, const void* B, float* C, int M, int K, int N);
+int cuda_matmul_f16_q4k(const void* A, const void* B, float* C, int M, int K, int N);
+int cuda_matmul_f16_q5k(const void* A, const void* B, float* C, int M, int K, int N);
+int cuda_matmul_f16_q2k(const void* A, const void* B, float* C, int M, int K, int N);
+int cuda_matmul_f16_q3k(const void* A, const void* B, float* C, int M, int K, int N);
+int cuda_matmul_f16_q6k(const void* A, const void* B, float* C, int M, int K, int N);
+
+// Debug helper
+int cuda_print_struct_sizes(void);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MAKARNA_CUDA_H
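To make the intended call pattern concrete, here is a minimal cgo sketch driving a few of the declarations above from Go. The build tag, include path, and link flags are assumptions for illustration, and it assumes cuda_add_f32 writes its result into its first argument, consistent with the in-place convention noted elsewhere in this header.

//go:build cuda

package main

/*
#cgo LDFLAGS: -lmakarna_cuda -lcudart
#include "makarna_cuda.h"
*/
import "C"

import (
	"fmt"
	"unsafe"
)

func main() {
	a := []float32{1, 2, 3, 4}
	b := []float32{10, 20, 30, 40}
	nbytes := C.size_t(len(a) * 4)

	C.cuda_set_device(0)
	dA := C.cuda_malloc(nbytes) // device buffers
	dB := C.cuda_malloc(nbytes)
	defer C.cuda_free(dA)
	defer C.cuda_free(dB)

	// Host -> device, elementwise add on the default stream, device -> host.
	C.cuda_memcpy_h2d(dA, unsafe.Pointer(&a[0]), nbytes)
	C.cuda_memcpy_h2d(dB, unsafe.Pointer(&b[0]), nbytes)
	C.cuda_add_f32((*C.float)(dA), (*C.float)(dB), C.size_t(len(a)))
	C.cuda_synchronize()
	C.cuda_memcpy_d2h(unsafe.Pointer(&a[0]), dA, nbytes)

	fmt.Println(a) // [11 22 33 44] if the add accumulates into a
}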

+ 254 - 0
pkg/backend/device/device.go

@@ -0,0 +1,254 @@
+// Package device provides cross-device tensor operations and placement management.
+// It serves as the central hub for device-aware computation in the makarna engine.
+package device
+
+import (
+	"fmt"
+	"sync"
+	"unsafe"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/backend/cuda"
+	"makarna/pkg/tensor"
+)
+
+// WeightCache caches GPU copies of weights to avoid repeated H2D transfers.
+// Thread-safe for concurrent layer execution.
+type WeightCache struct {
+	mu     sync.RWMutex
+	cache  map[string]*cuda.Tensor // key: "layer_idx:weight_name"
+	gpuID  int
+}
+
+// NewWeightCache creates a new weight cache for a specific GPU.
+func NewWeightCache(gpuID int) *WeightCache {
+	return &WeightCache{
+		cache: make(map[string]*cuda.Tensor),
+		gpuID: gpuID,
+	}
+}
+
+// Get retrieves a cached GPU tensor, returning nil if not cached.
+func (wc *WeightCache) Get(key string) *cuda.Tensor {
+	wc.mu.RLock()
+	defer wc.mu.RUnlock()
+	return wc.cache[key]
+}
+
+// Put adds a GPU tensor to the cache.
+func (wc *WeightCache) Put(key string, t *cuda.Tensor) {
+	wc.mu.Lock()
+	defer wc.mu.Unlock()
+	wc.cache[key] = t
+}
+
+// Clear frees all cached GPU tensors.
+func (wc *WeightCache) Clear() {
+	wc.mu.Lock()
+	defer wc.mu.Unlock()
+	wc.cache = make(map[string]*cuda.Tensor)
+}
+
+// EnsureOn returns a tensor on the requested placement, copying if needed.
+// For CPU tensors going to CUDA, this creates a NEW tensor each time.
+// Use EnsureOnCached for weight tensors that should be cached.
+func EnsureOn(t tensor.Tensor, target tensor.DevicePlacement) (tensor.Tensor, error) {
+	if twp, ok := t.(tensor.TensorWithPlacement); ok {
+		if twp.Placement() == target.Normalize() {
+			return t, nil
+		}
+	}
+
+	switch target.Type {
+	case tensor.CPU:
+		return toCPU(t)
+	case tensor.CUDA:
+		return toCUDA(t, target.GPU)
+	default:
+		return nil, fmt.Errorf("unsupported target device %v", target.Type)
+	}
+}
+
+// EnsureOnCached is like EnsureOn but uses a cache for weight tensors.
+// The key should uniquely identify the weight (e.g., "layer_0:wq").
+func EnsureOnCached(t tensor.Tensor, target tensor.DevicePlacement, cache *WeightCache, key string) (tensor.Tensor, error) {
+	if target.Type != tensor.CUDA {
+		return EnsureOn(t, target)
+	}
+
+	if cache == nil {
+		return EnsureOn(t, target)
+	}
+
+	// Check cache first
+	if cached := cache.Get(key); cached != nil {
+		return cached, nil
+	}
+
+	// Not cached, create and cache
+	result, err := toCUDA(t, target.GPU)
+	if err != nil {
+		return nil, err
+	}
+
+	cudaTensor, ok := result.(*cuda.Tensor)
+	if ok {
+		cache.Put(key, cudaTensor)
+	}
+
+	return result, nil
+}
+
+// CUDAAvailable returns whether CUDA is available.
+func CUDAAvailable() bool {
+	return cuda.Available()
+}
+
+func toCPU(t tensor.Tensor) (tensor.Tensor, error) {
+	if c, ok := t.(*cpu.Tensor); ok {
+		return c, nil
+	}
+	switch src := t.(type) {
+	case *cuda.Tensor:
+		if !cuda.Available() {
+			return nil, fmt.Errorf("CUDA not available")
+		}
+		out := cpu.NewTensor(src.Shape(), nil)
+		host := out.DataFloat32()
+		if err := src.CopyToHost(host); err != nil {
+			return nil, fmt.Errorf("copy to host failed: %w", err)
+		}
+		return out, nil
+	default:
+		return nil, fmt.Errorf("toCPU: unsupported tensor type %T", t)
+	}
+}
+
+func toCUDA(t tensor.Tensor, gpu int) (tensor.Tensor, error) {
+	if !cuda.Available() {
+		return nil, fmt.Errorf("CUDA not available - build with -tags=cuda")
+	}
+
+	switch src := t.(type) {
+	case *cuda.Tensor:
+		if src.GPU() == gpu {
+			return src, nil
+		}
+		if src.DType() != tensor.Float32 {
+			return nil, fmt.Errorf("cross-GPU tensor copy only supports float32, got %v", src.DType())
+		}
+		out, err := cuda.NewTensor(src.Shape(), src.DType(), gpu)
+		if err != nil {
+			return nil, err
+		}
+		size := uintptr(src.Shape().NumElements() * src.DType().Size())
+		if err := cuda.MemcpyD2D(out.Data().(unsafe.Pointer), src.Data().(unsafe.Pointer), size, gpu); err != nil {
+			// Conservative fallback: stage via host.
+			host := make([]float32, src.Shape().NumElements())
+			if err2 := src.CopyToHost(host); err2 != nil {
+				out.Free()
+				return nil, fmt.Errorf("cross-GPU copy D2H failed: %w", err2)
+			}
+			if err2 := out.CopyFrom(host); err2 != nil {
+				out.Free()
+				return nil, fmt.Errorf("cross-GPU copy H2D failed: %w", err2)
+			}
+		}
+		return out, nil
+	}
+
+	// For quantized tensors, we need dequantization first
+	if t.DType() != tensor.Float32 {
+		return nil, fmt.Errorf("toCUDA: only float32 currently supported, got %v", t.DType())
+	}
+
+	out, err := cuda.NewTensor(t.Shape(), t.DType(), gpu)
+	if err != nil {
+		return nil, err
+	}
+
+	switch s := t.(type) {
+	case *cpu.Tensor:
+		if err := out.CopyFrom(s.DataFloat32()); err != nil {
+			out.Free()
+			return nil, err
+		}
+	default:
+		out.Free()
+		return nil, fmt.Errorf("toCUDA: unsupported source type %T", t)
+	}
+
+	return out, nil
+}
+
+// DeviceDispatcher manages per-device operations and caching.
+type DeviceDispatcher struct {
+	layerDevices []tensor.DevicePlacement
+	weightCaches map[int]*WeightCache // gpuID -> cache
+	mu           sync.RWMutex
+}
+
+// NewDeviceDispatcher creates a dispatcher with the given layer placements.
+func NewDeviceDispatcher(layerDevices []tensor.DevicePlacement) *DeviceDispatcher {
+	dd := &DeviceDispatcher{
+		layerDevices: layerDevices,
+		weightCaches: make(map[int]*WeightCache),
+	}
+
+	// Pre-create caches for each GPU mentioned
+	for _, p := range layerDevices {
+		if p.Type == tensor.CUDA {
+			if _, exists := dd.weightCaches[p.GPU]; !exists {
+				dd.weightCaches[p.GPU] = NewWeightCache(p.GPU)
+			}
+		}
+	}
+
+	return dd
+}
+
+// LayerPlacement returns the device placement for a layer.
+func (dd *DeviceDispatcher) LayerPlacement(layerIdx int) tensor.DevicePlacement {
+	if layerIdx >= 0 && layerIdx < len(dd.layerDevices) {
+		return dd.layerDevices[layerIdx]
+	}
+	return tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
+}
+
+// GetWeightCache returns the weight cache for a GPU, creating one if needed.
+func (dd *DeviceDispatcher) GetWeightCache(gpuID int) *WeightCache {
+	dd.mu.Lock()
+	defer dd.mu.Unlock()
+
+	if cache, exists := dd.weightCaches[gpuID]; exists {
+		return cache
+	}
+
+	cache := NewWeightCache(gpuID)
+	dd.weightCaches[gpuID] = cache
+	return cache
+}
+
+// IsLayerOnGPU returns true if the layer should run on GPU.
+func (dd *DeviceDispatcher) IsLayerOnGPU(layerIdx int) bool {
+	p := dd.LayerPlacement(layerIdx)
+	return p.Type == tensor.CUDA
+}
+
+// NumGPULayers counts how many layers are placed on GPU.
+func (dd *DeviceDispatcher) NumGPULayers() int {
+	count := 0
+	for _, p := range dd.layerDevices {
+		if p.Type == tensor.CUDA {
+			count++
+		}
+	}
+	return count
+}
+
+// Clear frees all cached resources.
+func (dd *DeviceDispatcher) Clear() {
+	dd.mu.Lock()
+	defer dd.mu.Unlock()
+	for _, cache := range dd.weightCaches {
+		cache.Clear()
+	}
+}
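A short usage sketch of the dispatcher plus weight cache, assuming a hypothetical layer split and an invented cache key; the packages referenced are the ones added in this commit.

package main

import (
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/device"
	"makarna/pkg/tensor"
)

func main() {
	// Hypothetical split: first two layers on GPU 0, the rest on CPU.
	dd := device.NewDeviceDispatcher([]tensor.DevicePlacement{
		{Type: tensor.CUDA, GPU: 0},
		{Type: tensor.CUDA, GPU: 0},
		{Type: tensor.CPU, GPU: -1},
	})
	defer dd.Clear()

	w := cpu.NewTensor(tensor.Shape{4, 4}, nil) // stand-in weight

	// The first call uploads the weight; later calls with the same key reuse the GPU copy.
	cache := dd.GetWeightCache(0)
	gpuW, err := device.EnsureOnCached(w, dd.LayerPlacement(0), cache, "layer_0:wq")
	if err != nil {
		fmt.Println("no CUDA build, staying on CPU:", err)
		return
	}
	_ = gpuW
}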

+ 65 - 0
pkg/chat/parse_tool_calls.go

@@ -0,0 +1,65 @@
+package chat
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"strings"
+)
+
+func StripThinking(s string) (reasoning string, content string) {
+	start := strings.Index(s, "<think>")
+	end := strings.Index(s, "</think>")
+	if start >= 0 && end > start {
+		reasoning = s[start+len("<think>"):end]
+		content = s[:start] + s[end+len("</think>"):]
+		return strings.Trim(reasoning, "\n"), strings.TrimLeft(content, "\n")
+	}
+	// If thinking starts but doesn't end (truncated generation), drop it from visible output.
+	if start >= 0 && end < 0 {
+		reasoning = s[start+len("<think>"):]
+		content = s[:start]
+		return strings.Trim(reasoning, "\n"), strings.TrimSpace(content)
+	}
+	return "", s
+}
+
+func ExtractToolCalls(s string) (content string, calls []ToolCall, err error) {
+	var out strings.Builder
+	rest := s
+	for {
+		start := strings.Index(rest, "<tool_call>")
+		if start < 0 {
+			out.WriteString(rest)
+			break
+		}
+		out.WriteString(rest[:start])
+		rest = rest[start+len("<tool_call>"):]
+		end := strings.Index(rest, "</tool_call>")
+		if end < 0 {
+			return "", nil, fmt.Errorf("unterminated <tool_call>")
+		}
+		block := strings.TrimSpace(rest[:end])
+		rest = rest[end+len("</tool_call>"):]
+
+		// Some models include trailing commas/newlines; keep robust.
+		block = strings.Trim(block, "\n\r\t ")
+		var raw struct {
+			Name      string          `json:"name"`
+			Arguments json.RawMessage `json:"arguments"`
+		}
+		dec := json.NewDecoder(bytes.NewReader([]byte(block)))
+		dec.DisallowUnknownFields()
+		if err := dec.Decode(&raw); err != nil {
+			// Fallback: allow unknown fields.
+			if err2 := json.Unmarshal([]byte(block), &raw); err2 != nil {
+				return "", nil, fmt.Errorf("parse tool_call json: %w", err)
+			}
+		}
+		if raw.Name == "" {
+			return "", nil, fmt.Errorf("tool_call missing name")
+		}
+		calls = append(calls, ToolCall{Name: raw.Name, Arguments: raw.Arguments})
+	}
+	return out.String(), calls, nil
+}
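The two helpers are meant to be chained on a raw generation: strip the <think> block first, then pull tool calls out of the visible remainder. A small sketch with an invented model output:

package main

import (
	"fmt"

	"makarna/pkg/chat"
)

func main() {
	raw := "<think>plan the call</think>\n<tool_call>\n" +
		"{\"name\":\"get_weather\",\"arguments\":{\"city\":\"Istanbul\"}}\n</tool_call>"

	reasoning, visible := chat.StripThinking(raw)
	content, calls, err := chat.ExtractToolCalls(visible)
	if err != nil {
		panic(err)
	}
	fmt.Println(reasoning)     // plan the call
	fmt.Println(content)       // empty here: the tool call was the whole visible text
	fmt.Println(calls[0].Name) // get_weather
}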

+ 38 - 0
pkg/chat/parse_tool_calls_test.go

@@ -0,0 +1,38 @@
+package chat
+
+import (
+	"encoding/json"
+	"testing"
+)
+
+func TestStripThinking(t *testing.T) {
+	r, c := StripThinking("<think>hi</think>\n\nhello")
+	if r != "hi" {
+		t.Fatalf("reasoning=%q", r)
+	}
+	if c != "hello" {
+		t.Fatalf("content=%q", c)
+	}
+}
+
+func TestStripThinking_Truncated(t *testing.T) {
+	_, c := StripThinking("before<think>this is truncated")
+	if c != "before" {
+		t.Fatalf("content=%q", c)
+	}
+}
+
+func TestExtractToolCalls(t *testing.T) {
+	args, _ := json.Marshal(map[string]any{"x": 1})
+	s := "before<tool_call>\n{\"name\":\"f\",\"arguments\":" + string(args) + "}\n</tool_call>after"
+	content, calls, err := ExtractToolCalls(s)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if content != "beforeafter" {
+		t.Fatalf("content=%q", content)
+	}
+	if len(calls) != 1 || calls[0].Name != "f" {
+		t.Fatalf("calls=%+v", calls)
+	}
+}

+ 51 - 0
pkg/chat/registry.go

@@ -0,0 +1,51 @@
+package chat
+
+import (
+	"fmt"
+	"strings"
+	"sync"
+)
+
+// Renderer renders messages into a single prompt string.
+//
+//nolint:revive // exported API
+type Renderer interface {
+	Render(messages []Message, opts Options) (string, error)
+}
+
+var (
+	mu        sync.RWMutex
+	renderers = map[string]Renderer{}
+)
+
+// Register registers a renderer for a model architecture (e.g. "qwen3", "llama").
+func Register(architecture string, r Renderer) {
+	mu.Lock()
+	defer mu.Unlock()
+	renderers[strings.ToLower(architecture)] = r
+}
+
+// RendererForArchitecture returns the renderer for the given architecture.
+func RendererForArchitecture(architecture string) (Renderer, bool) {
+	mu.RLock()
+	defer mu.RUnlock()
+	r, ok := renderers[strings.ToLower(architecture)]
+	return r, ok
+}
+
+// RenderForArchitecture renders chat messages for the given architecture.
+func RenderForArchitecture(architecture string, messages []Message, opts Options) (string, error) {
+	if r, ok := RendererForArchitecture(architecture); ok {
+		return r.Render(messages, opts)
+	}
+	return (&Qwen3Renderer{}).Render(messages, opts)
+}
+
+// MustRenderForArchitecture is a convenience helper.
+func MustRenderForArchitecture(architecture string, messages []Message, opts Options) string {
+	out, err := RenderForArchitecture(architecture, messages, opts)
+	if err != nil {
+		panic(fmt.Errorf("render chat template: %w", err))
+	}
+	return out
+}
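Any type implementing Render can be plugged in for a new architecture; the renderer below is a toy stand-in used only to show the registration pattern, and the architecture name is invented.

package main

import (
	"strings"

	"makarna/pkg/chat"
)

// plainRenderer joins messages as "role: content" lines (illustrative only).
type plainRenderer struct{}

func (plainRenderer) Render(messages []chat.Message, opts chat.Options) (string, error) {
	var sb strings.Builder
	for _, m := range messages {
		sb.WriteString(m.Role + ": " + m.Content + "\n")
	}
	if opts.AddGenerationPrompt {
		sb.WriteString("assistant: ")
	}
	return sb.String(), nil
}

func main() {
	chat.Register("toymodel", plainRenderer{})
	out, _ := chat.RenderForArchitecture("toymodel",
		[]chat.Message{{Role: "user", Content: "hi"}},
		chat.Options{AddGenerationPrompt: true})
	_ = out // "user: hi\nassistant: "
}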

+ 140 - 0
pkg/chat/render_qwen3.go

@@ -0,0 +1,140 @@
+package chat
+
+import (
+	"encoding/json"
+	"fmt"
+	"strings"
+)
+
+// Qwen3Renderer implements the Qwen3 chat template behavior commonly shipped
+// in tokenizer_config.json (including tool-call wrappers).
+// It is a deterministic renderer (not a general Jinja interpreter).
+//
+//nolint:revive // exported API
+type Qwen3Renderer struct{}
+
+func (r *Qwen3Renderer) Render(messages []Message, opts Options) (string, error) {
+	var sb strings.Builder
+
+	// Tools header block (only when tools are provided)
+	if len(opts.Tools) > 0 {
+		sb.WriteString("<|im_start|>system\n")
+		if len(messages) > 0 && messages[0].Role == "system" {
+			sb.WriteString(messages[0].Content)
+			sb.WriteString("\n\n")
+		}
+		sb.WriteString("# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>")
+		for _, tool := range opts.Tools {
+			b, err := json.Marshal(tool)
+			if err != nil {
+				return "", fmt.Errorf("marshal tool: %w", err)
+			}
+			sb.WriteString("\n")
+			sb.Write(b)
+		}
+		sb.WriteString("\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n")
+	} else {
+		// No-tools path: include system message only if first message is system.
+		if len(messages) > 0 && messages[0].Role == "system" {
+			sb.WriteString("<|im_start|>system\n")
+			sb.WriteString(messages[0].Content)
+			sb.WriteString("<|im_end|>\n")
+		}
+	}
+
+	// Find last user query index for thinking rendering.
+	lastQueryIdx := len(messages) - 1
+	multiStepTool := true
+	for i := len(messages) - 1; i >= 0; i-- {
+		m := messages[i]
+		if multiStepTool && m.Role == "user" && !strings.HasPrefix(m.Content, "<tool_response>") {
+			multiStepTool = false
+			lastQueryIdx = i
+			break
+		}
+	}
+
+	// Render message stream.
+	inToolGroup := false
+	for i := 0; i < len(messages); i++ {
+		m := messages[i]
+
+		switch m.Role {
+		case "tool":
+			// Tool responses are wrapped as a fake user message with <tool_response> blocks.
+			if !inToolGroup {
+				sb.WriteString("<|im_start|>user")
+				inToolGroup = true
+			}
+			sb.WriteString("\n<tool_response>\n")
+			sb.WriteString(m.Content)
+			sb.WriteString("\n</tool_response>")
+
+			// Close tool group if next isn't tool.
+			if i == len(messages)-1 || messages[i+1].Role != "tool" {
+				sb.WriteString("<|im_end|>\n")
+				inToolGroup = false
+			}
+			continue
+		default:
+			if inToolGroup {
+				// Safety: close an unterminated tool group.
+				sb.WriteString("<|im_end|>\n")
+				inToolGroup = false
+			}
+		}
+
+		content := m.Content
+		if m.Role == "user" || (m.Role == "system" && i != 0) {
+			sb.WriteString("<|im_start|>")
+			sb.WriteString(m.Role)
+			sb.WriteString("\n")
+			sb.WriteString(content)
+			sb.WriteString("<|im_end|>\n")
+			continue
+		}
+
+		if m.Role == "assistant" {
+			sb.WriteString("<|im_start|>assistant\n")
+
+			reasoning := m.ReasoningContent
+			if opts.EnableThinking {
+				// Render thinking only after the last user query index.
+				if i > lastQueryIdx {
+					sb.WriteString("<think>\n")
+					sb.WriteString(strings.Trim(reasoning, "\n"))
+					sb.WriteString("\n</think>\n\n")
+				}
+			}
+
+			sb.WriteString(strings.TrimLeft(content, "\n"))
+
+			// Render tool calls, if any.
+			if len(m.ToolCalls) > 0 {
+				for j, tc := range m.ToolCalls {
+					if (j == 0 && content != "") || j > 0 {
+						sb.WriteString("\n")
+					}
+					sb.WriteString("<tool_call>\n{\"name\": \"")
+					sb.WriteString(tc.Name)
+					sb.WriteString("\", \"arguments\": ")
+					if len(tc.Arguments) == 0 {
+						sb.WriteString("{}")
+					} else {
+						sb.Write(tc.Arguments)
+					}
+					sb.WriteString("}\n</tool_call>")
+				}
+			}
+
+			sb.WriteString("<|im_end|>\n")
+			continue
+		}
+	}
+
+	if opts.AddGenerationPrompt {
+		sb.WriteString("<|im_start|>assistant\n")
+	}
+
+	return sb.String(), nil
+}
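For a tool-free, single-turn prompt the renderer emits plain ChatML; the expected output in the sketch below follows directly from the code above.

package main

import (
	"fmt"

	"makarna/pkg/chat"
)

func main() {
	out, err := (&chat.Qwen3Renderer{}).Render(
		[]chat.Message{{Role: "user", Content: "Hi"}},
		chat.Options{AddGenerationPrompt: true},
	)
	if err != nil {
		panic(err)
	}
	fmt.Print(out)
	// <|im_start|>user
	// Hi<|im_end|>
	// <|im_start|>assistant
}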

+ 70 - 0
pkg/chat/render_qwen3_test.go

@@ -0,0 +1,70 @@
+package chat
+
+import (
+	"encoding/json"
+	"strings"
+	"testing"
+)
+
+func TestQwen3Renderer_ToolsAndToolCalls(t *testing.T) {
+	r := &Qwen3Renderer{}
+
+	toolSchema := map[string]any{
+		"name": "get_weather",
+		"description": "Get current weather",
+		"parameters": map[string]any{
+			"type": "object",
+			"properties": map[string]any{
+				"city": map[string]any{"type": "string"},
+			},
+		},
+	}
+
+	args, err := json.Marshal(map[string]any{"city": "Istanbul"})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	out, err := r.Render(
+		[]Message{
+			{Role: "user", Content: "Weather?"},
+			{Role: "assistant", Content: "", ToolCalls: []ToolCall{{Name: "get_weather", Arguments: args}}},
+		},
+		Options{AddGenerationPrompt: false, EnableThinking: true, Tools: []any{toolSchema}},
+	)
+	if err != nil {
+		t.Fatalf("render failed: %v", err)
+	}
+
+	if !strings.Contains(out, "# Tools") {
+		t.Fatalf("missing tools header")
+	}
+	if !strings.Contains(out, "<tools>") || !strings.Contains(out, "</tools>") {
+		t.Fatalf("missing <tools> wrapper")
+	}
+	if !strings.Contains(out, "<tool_call>") || !strings.Contains(out, "\"name\": \"get_weather\"") {
+		t.Fatalf("missing tool_call encoding: %q", out)
+	}
+}
+
+func TestQwen3Renderer_ToolResponseGroup(t *testing.T) {
+	r := &Qwen3Renderer{}
+
+	out, err := r.Render(
+		[]Message{
+			{Role: "user", Content: "Do thing"},
+			{Role: "tool", Content: "{\"ok\":true}"},
+			{Role: "assistant", Content: "Done"},
+		},
+		Options{AddGenerationPrompt: false, EnableThinking: true},
+	)
+	if err != nil {
+		t.Fatalf("render failed: %v", err)
+	}
+	if !strings.Contains(out, "<tool_response>") || !strings.Contains(out, "</tool_response>") {
+		t.Fatalf("missing tool_response wrapper: %q", out)
+	}
+	if !strings.Contains(out, "<|im_start|>user\n<tool_response>") {
+		t.Fatalf("tool response not wrapped as user group: %q", out)
+	}
+}

+ 31 - 0
pkg/chat/template.go

@@ -0,0 +1,31 @@
+// Package chat provides chat template formatting for LLM prompts
+package chat
+
+import (
+	"strings"
+)
+
+// FormatQwen formats messages for Qwen/ChatML style models.
+// Uses the ChatML "<|im_start|>role\ncontent<|im_end|>" format.
+func FormatQwen(messages []Message, addGenerationPrompt bool) string {
+	out, _ := (&Qwen3Renderer{}).Render(messages, Options{AddGenerationPrompt: addGenerationPrompt})
+	return out
+}
+
+// FormatLlama formats messages for Llama 3 style models.
+// Note: this currently delegates to the ChatML renderer rather than a
+// dedicated Llama 3 template.
+func FormatLlama(messages []Message, addGenerationPrompt bool) string {
+	out, _ := (&Qwen3Renderer{}).Render(messages, Options{AddGenerationPrompt: addGenerationPrompt})
+	return out
+}
+
+// Format automatically selects format based on model type
+func Format(messages []Message, modelType string, addGenerationPrompt bool) string {
+	switch {
+	case strings.Contains(strings.ToLower(modelType), "qwen"):
+		return FormatQwen(messages, addGenerationPrompt)
+	case strings.Contains(strings.ToLower(modelType), "llama"):
+		return FormatLlama(messages, addGenerationPrompt)
+	default:
+		return FormatQwen(messages, addGenerationPrompt) // Default to ChatML
+	}
+}
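Format is the convenience entry point when only a model-type string is available; the model name below is illustrative.

package main

import (
	"fmt"

	"makarna/pkg/chat"
)

func main() {
	msgs := []chat.Message{
		{Role: "system", Content: "You are helpful."},
		{Role: "user", Content: "Hello"},
	}
	// Model types containing "qwen" (and anything unrecognized) get ChatML.
	fmt.Print(chat.Format(msgs, "qwen3", true))
}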

+ 37 - 0
pkg/chat/types.go

@@ -0,0 +1,37 @@
+package chat
+
+import "encoding/json"
+
+// Message represents a chat message.
+// Role is typically: "system", "user", "assistant", "tool".
+// ToolCalls is used for assistant messages that request tool/function calls.
+// ReasoningContent optionally carries chain-of-thought style content for templates that support it.
+// Content is the visible message content.
+//
+// Note: We intentionally keep this struct generic and close to common chat template expectations.
+//
+//nolint:revive // exported API
+type Message struct {
+	Role             string     `json:"role"`
+	Content          string     `json:"content"`
+	ReasoningContent string     `json:"reasoning_content,omitempty"`
+	ToolCalls        []ToolCall `json:"tool_calls,omitempty"`
+}
+
+// ToolCall represents a single tool/function call requested by the assistant.
+// Arguments may be either a raw JSON object or a string (some models emit stringified JSON).
+//
+//nolint:revive // exported API
+type ToolCall struct {
+	Name      string          `json:"name"`
+	Arguments json.RawMessage `json:"arguments"`
+}
+
+// Options control chat template rendering.
+//
+//nolint:revive // exported API
+type Options struct {
+	AddGenerationPrompt bool
+	EnableThinking      bool
+	Tools               []any
+}
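A sketch of how a conversation with a tool round-trip is represented with these types (all values invented):

package main

import (
	"encoding/json"

	"makarna/pkg/chat"
)

func main() {
	args, _ := json.Marshal(map[string]any{"city": "Istanbul"})
	_ = []chat.Message{
		{Role: "user", Content: "Weather in Istanbul?"},
		{Role: "assistant", ToolCalls: []chat.ToolCall{{Name: "get_weather", Arguments: args}}},
		{Role: "tool", Content: `{"temp_c": 21}`},
		{Role: "assistant", Content: "It is 21 degrees in Istanbul."},
	}
}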

+ 184 - 0
pkg/compute/activation.go

@@ -0,0 +1,184 @@
+// Package compute provides device-agnostic computation with hybrid CPU/GPU support.
+package compute
+
+import (
+	"fmt"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/backend/cuda"
+	"makarna/pkg/backend/device"
+	"makarna/pkg/tensor"
+)
+
+// Activation wraps a tensor with device tracking.
+// It enables efficient hybrid execution where transfers only happen
+// when crossing device boundaries.
+type Activation struct {
+	tensor    tensor.Tensor
+	placement tensor.DevicePlacement
+}
+
+// NewActivation creates an activation on the specified device.
+func NewActivation(shape tensor.Shape, placement tensor.DevicePlacement) (*Activation, error) {
+	var t tensor.Tensor
+	var err error
+
+	if placement.Type == tensor.CUDA && device.CUDAAvailable() {
+		t, err = cuda.NewTensor(shape, tensor.Float32, placement.GPU)
+		if err != nil {
+		if err != nil {
+			// Fallback to CPU; clear the error so callers still get a usable activation.
+			t = cpu.NewTensor(shape, nil)
+			placement = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
+			err = nil
+	} else {
+		t = cpu.NewTensor(shape, nil)
+		placement = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
+	}
+
+	return &Activation{tensor: t, placement: placement.Normalize()}, err
+}
+
+// NewActivationFrom wraps an existing tensor.
+func NewActivationFrom(t tensor.Tensor) *Activation {
+	var placement tensor.DevicePlacement
+
+	if ct, ok := t.(*cuda.Tensor); ok {
+		placement = tensor.DevicePlacement{Type: tensor.CUDA, GPU: ct.GPU()}
+	} else {
+		placement = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
+	}
+
+	return &Activation{tensor: t, placement: placement.Normalize()}
+}
+
+// Tensor returns the underlying tensor.
+func (a *Activation) Tensor() tensor.Tensor {
+	return a.tensor
+}
+
+// Placement returns the current device placement.
+func (a *Activation) Placement() tensor.DevicePlacement {
+	return a.placement
+}
+
+// IsGPU returns true if the activation is on GPU.
+func (a *Activation) IsGPU() bool {
+	return a.placement.Type == tensor.CUDA
+}
+
+// Shape returns the tensor shape.
+func (a *Activation) Shape() tensor.Shape {
+	return a.tensor.Shape()
+}
+
+// EnsureOn moves the activation to the target device if needed.
+// Returns true if a transfer occurred.
+func (a *Activation) EnsureOn(target tensor.DevicePlacement) (transferred bool, err error) {
+	target = target.Normalize()
+
+	// Already on target device
+	if a.placement == target {
+		return false, nil
+	}
+
+	// Transfer needed
+	newTensor, err := device.EnsureOn(a.tensor, target)
+	if err != nil {
+		return false, fmt.Errorf("activation transfer %v -> %v: %w", a.placement, target, err)
+	}
+
+	// Free old GPU tensor to prevent memory leak
+	if oldCT, ok := a.tensor.(*cuda.Tensor); ok && oldCT != nil {
+		oldCT.Free()
+	}
+
+	a.tensor = newTensor
+	a.placement = target
+	return true, nil
+}
+
+// AsCPU returns the tensor as *cpu.Tensor, transferring if needed.
+func (a *Activation) AsCPU() (*cpu.Tensor, error) {
+	if _, err := a.EnsureOn(tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}); err != nil {
+		return nil, err
+	}
+	return a.tensor.(*cpu.Tensor), nil
+}
+
+// AsCUDA returns the tensor as *cuda.Tensor, transferring if needed.
+func (a *Activation) AsCUDA(gpu int) (*cuda.Tensor, error) {
+	if _, err := a.EnsureOn(tensor.DevicePlacement{Type: tensor.CUDA, GPU: gpu}); err != nil {
+		return nil, err
+	}
+	return a.tensor.(*cuda.Tensor), nil
+}
+
+// ReplaceWith replaces the underlying tensor and updates placement.
+func (a *Activation) ReplaceWith(t tensor.Tensor) {
+	if a.tensor != nil {
+		if oldCT, ok := a.tensor.(*cuda.Tensor); ok {
+			if newCT, ok2 := t.(*cuda.Tensor); ok2 {
+				if oldCT != newCT {
+					oldCT.Free()
+				}
+			} else {
+				oldCT.Free()
+			}
+		}
+	}
+
+	a.tensor = t
+	if ct, ok := t.(*cuda.Tensor); ok {
+		a.placement = tensor.DevicePlacement{Type: tensor.CUDA, GPU: ct.GPU()}
+	} else {
+		a.placement = tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
+	}
+}
+
+// Clone creates a deep copy of the activation on the same device.
+func (a *Activation) Clone() (*Activation, error) {
+	if a.IsGPU() {
+		ct := a.tensor.(*cuda.Tensor)
+		newT, err := cuda.NewTensor(ct.Shape(), ct.DType(), ct.GPU())
+		if err != nil {
+			return nil, err
+		}
+		// Copy GPU to GPU using CopyToHost then CopyFrom (simple path)
+		tempBuf := make([]float32, ct.Shape().NumElements())
+		if err := ct.CopyToHost(tempBuf); err != nil {
+			return nil, err
+		}
+		if err := newT.CopyFrom(tempBuf); err != nil {
+			return nil, err
+		}
+		return &Activation{tensor: newT, placement: a.placement}, nil
+	}
+
+	// CPU clone
+	src := a.tensor.(*cpu.Tensor)
+	dst := cpu.NewTensor(src.Shape(), nil)
+	copy(dst.DataFloat32(), src.DataFloat32())
+	return &Activation{tensor: dst, placement: a.placement}, nil
+}
+
+// CopyFrom copies data from a CPU tensor to this activation
+func (a *Activation) CopyFrom(t *cpu.Tensor) error {
+	if a.IsGPU() {
+		return a.tensor.(*cuda.Tensor).CopyFrom(t.DataFloat32())
+	}
+	src := t.DataFloat32()
+	dst := a.tensor.(*cpu.Tensor).DataFloat32()
+	copy(dst, src)
+	return nil
+}
+
+// FreeActivation frees GPU memory if the activation is on GPU.
+// Safe to call on nil or CPU activations.
+func FreeActivation(a *Activation) {
+	if a == nil {
+		return
+	}
+	if ct, ok := a.tensor.(*cuda.Tensor); ok && ct != nil {
+		ct.Free()
+	}
+}
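A minimal sketch of the activation lifecycle: wrap a CPU tensor, opportunistically move it to GPU 0, and bring it back for inspection. Without a CUDA build the transfer simply fails and the data stays on CPU.

package main

import (
	"fmt"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/compute"
	"makarna/pkg/tensor"
)

func main() {
	x := cpu.NewTensor(tensor.Shape{1, 8}, nil)
	act := compute.NewActivationFrom(x)
	defer compute.FreeActivation(act)

	// Best-effort move to GPU 0.
	if moved, err := act.EnsureOn(tensor.DevicePlacement{Type: tensor.CUDA, GPU: 0}); err != nil {
		fmt.Println("CPU-only build, staying on CPU:", err)
	} else if moved {
		fmt.Println("activation moved to GPU 0")
	}

	// AsCPU transfers back (if needed) and exposes the host data.
	host, err := act.AsCPU()
	if err != nil {
		panic(err)
	}
	_ = host.DataFloat32()
}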

+ 122 - 0
pkg/compute/compute.go

@@ -0,0 +1,122 @@
+// Package compute provides device-agnostic computation dispatching.
+// Operations automatically route to the appropriate backend (CPU/CUDA)
+// based on tensor placement, eliminating manual device management in model code.
+package compute
+
+import (
+	"fmt"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/backend/cuda"
+	"makarna/pkg/backend/device"
+	"makarna/pkg/tensor"
+)
+
+// Context holds computation state for a forward pass.
+type Context struct {
+	Dispatcher *device.DeviceDispatcher
+	LayerIdx   int
+	Scratch    *ScratchSpace
+	CPUMoE     bool // Keep MoE expert weights on CPU
+}
+
+// NewContext creates a computation context.
+func NewContext(dispatcher *device.DeviceDispatcher, layerIdx int) *Context {
+	return &Context{
+		Dispatcher: dispatcher,
+		LayerIdx:   layerIdx,
+	}
+}
+
+// Placement returns the current layer's device placement.
+func (c *Context) Placement() tensor.DevicePlacement {
+	if c.Dispatcher == nil {
+		return tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
+	}
+	return c.Dispatcher.LayerPlacement(c.LayerIdx)
+}
+
+// IsGPU returns true if current layer is on GPU.
+func (c *Context) IsGPU() bool {
+	return c.Placement().Type == tensor.CUDA
+}
+
+// EnsureWeight ensures a weight tensor is on the correct device with caching.
+func (c *Context) EnsureWeight(t tensor.Tensor, name string) (tensor.Tensor, error) {
+	if c.Dispatcher == nil {
+		return t, nil
+	}
+
+	placement := c.Placement()
+	if placement.Type == tensor.CPU {
+		return t, nil
+	}
+
+	cache := c.Dispatcher.GetWeightCache(placement.GPU)
+	key := fmt.Sprintf("%d:%s", c.LayerIdx, name)
+	return device.EnsureOnCached(t, placement, cache, key)
+}
+
+// EnsureActivation ensures an activation tensor is on the correct device.
+// Unlike weights, activations are not cached between forward passes.
+func (c *Context) EnsureActivation(t tensor.Tensor) (tensor.Tensor, error) {
+	if c.Dispatcher == nil {
+		return t, nil
+	}
+	return device.EnsureOn(t, c.Placement())
+}
+
+// Zeros creates a zero tensor on the appropriate device.
+func Zeros(ctx *Context, shape tensor.Shape) tensor.Tensor {
+	if ctx == nil || !ctx.IsGPU() || !device.CUDAAvailable() {
+		return cpu.NewTensor(shape, nil)
+	}
+
+	t, err := cuda.NewTensor(shape, tensor.Float32, ctx.Placement().GPU)
+	if err != nil {
+		// Fallback to CPU
+		return cpu.NewTensor(shape, nil)
+	}
+	return t
+}
+
+// ZerosCPU always creates a CPU tensor (for inputs/outputs).
+func ZerosCPU(shape tensor.Shape) *cpu.Tensor {
+	return cpu.NewTensor(shape, nil)
+}
+
+// ToCPU copies a tensor to CPU if needed.
+func ToCPU(t tensor.Tensor) (*cpu.Tensor, error) {
+	if cpuT, ok := t.(*cpu.Tensor); ok {
+		return cpuT, nil
+	}
+
+	result, err := device.EnsureOn(t, tensor.DevicePlacement{Type: tensor.CPU, GPU: -1})
+	if err != nil {
+		return nil, err
+	}
+	return result.(*cpu.Tensor), nil
+}
+
+// Copy copies data between tensors, handling cross-device copies.
+func Copy(dst, src tensor.Tensor) error {
+	// Same device, same type
+	if dstCPU, ok := dst.(*cpu.Tensor); ok {
+		if srcCPU, ok := src.(*cpu.Tensor); ok {
+			copy(dstCPU.DataFloat32(), srcCPU.DataFloat32())
+			return nil
+		}
+	}
+
+	if dstCUDA, ok := dst.(*cuda.Tensor); ok {
+		if srcCUDA, ok := src.(*cuda.Tensor); ok {
+			// TODO: CUDA-to-CUDA copy kernel
+			_ = dstCUDA
+			_ = srcCUDA
+			return fmt.Errorf("CUDA-to-CUDA copy not implemented")
+		}
+	}
+
+	// Cross-device: need intermediate copy
+	return fmt.Errorf("cross-device copy requires explicit conversion")
+}
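Putting the pieces together, a per-layer Context routes weights and scratch buffers to that layer's device; the placement split and weight name below are invented for illustration.

package main

import (
	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/device"
	"makarna/pkg/compute"
	"makarna/pkg/tensor"
)

func main() {
	dd := device.NewDeviceDispatcher([]tensor.DevicePlacement{
		{Type: tensor.CUDA, GPU: 0}, // layer 0 on GPU
		{Type: tensor.CPU, GPU: -1}, // layer 1 on CPU
	})
	defer dd.Clear()

	ctx := compute.NewContext(dd, 0)

	w := cpu.NewTensor(tensor.Shape{16, 16}, nil)
	wDev, err := ctx.EnsureWeight(w, "wq")
	if err != nil {
		// CPU-only build: keep using the host copy.
		wDev = w
	}
	_ = wDev

	// Scratch output is allocated on whichever device the layer runs on.
	out := compute.Zeros(ctx, tensor.Shape{1, 16})
	_ = out
}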

+ 145 - 0
pkg/compute/compute_test.go

@@ -0,0 +1,145 @@
+package compute
+
+import (
+	"testing"
+
+	"makarna/pkg/backend/cpu"
+	"makarna/pkg/backend/device"
+	"makarna/pkg/tensor"
+)
+
+func TestLinearCPU(t *testing.T) {
+	// Input: [2, 3] @ Weight: [4, 3] = Output: [2, 4]
+	input := cpu.NewTensor(tensor.Shape{2, 3}, []float32{
+		1, 2, 3,
+		4, 5, 6,
+	})
+	weight := cpu.NewTensor(tensor.Shape{4, 3}, []float32{
+		1, 0, 0,
+		0, 1, 0,
+		0, 0, 1,
+		1, 1, 1,
+	})
+	output := cpu.NewTensor(tensor.Shape{2, 4}, nil)
+
+	ctx := NewContext(nil, 0) // nil dispatcher = CPU
+
+	if err := Linear(ctx, input, weight, output); err != nil {
+		t.Fatalf("Linear failed: %v", err)
+	}
+
+	expected := []float32{
+		1, 2, 3, 6,  // row 0: [1,2,3] dot each weight row
+		4, 5, 6, 15, // row 1
+	}
+
+	outData := output.DataFloat32()
+	for i, exp := range expected {
+		if diff := outData[i] - exp; diff < -0.001 || diff > 0.001 {
+			t.Errorf("output[%d] = %f, expected %f", i, outData[i], exp)
+		}
+	}
+}
+
+func TestRMSNorm(t *testing.T) {
+	x := cpu.NewTensor(tensor.Shape{1, 4}, []float32{1, 2, 3, 4})
+	w := cpu.NewTensor(tensor.Shape{4}, []float32{1, 1, 1, 1})
+
+	ctx := NewContext(nil, 0)
+	if err := RMSNorm(ctx, x, w, 1e-6); err != nil {
+		t.Fatalf("RMSNorm failed: %v", err)
+	}
+
+	// Check output is normalized
+	data := x.DataFloat32()
+	var ss float32
+	for _, v := range data {
+		ss += v * v
+	}
+	meanSquare := ss / 4
+
+	// After RMSNorm with unit weights, the mean of squared values should be ~1
+	if meanSquare < 0.9 || meanSquare > 1.1 {
+		t.Errorf("mean square after norm = %f, expected ~1.0", meanSquare)
+	}
+}
+
+func TestDeviceDispatcher(t *testing.T) {
+	placements := []tensor.DevicePlacement{
+		{Type: tensor.CUDA, GPU: 0},
+		{Type: tensor.CUDA, GPU: 0},
+		{Type: tensor.CPU, GPU: -1},
+		{Type: tensor.CPU, GPU: -1},
+	}
+
+	dd := device.NewDeviceDispatcher(placements)
+
+	if dd.NumGPULayers() != 2 {
+		t.Errorf("NumGPULayers = %d, expected 2", dd.NumGPULayers())
+	}
+
+	if !dd.IsLayerOnGPU(0) {
+		t.Error("Layer 0 should be on GPU")
+	}
+
+	if dd.IsLayerOnGPU(2) {
+		t.Error("Layer 2 should be on CPU")
+	}
+
+	p := dd.LayerPlacement(1)
+	if p.Type != tensor.CUDA {
+		t.Errorf("Layer 1 placement = %v, expected CUDA", p.Type)
+	}
+
+	// Beyond bounds defaults to CPU
+	p = dd.LayerPlacement(100)
+	if p.Type != tensor.CPU {
+		t.Errorf("Out of bounds placement = %v, expected CPU", p.Type)
+	}
+}
+
+func TestContextPlacement(t *testing.T) {
+	placements := []tensor.DevicePlacement{
+		{Type: tensor.CUDA, GPU: 0},
+		{Type: tensor.CPU, GPU: -1},
+	}
+
+	dd := device.NewDeviceDispatcher(placements)
+
+	ctx0 := NewContext(dd, 0)
+	if !ctx0.IsGPU() {
+		t.Error("Context 0 should be GPU")
+	}
+
+	ctx1 := NewContext(dd, 1)
+	if ctx1.IsGPU() {
+		t.Error("Context 1 should be CPU")
+	}
+
+	// Nil dispatcher
+	ctxNil := NewContext(nil, 0)
+	if ctxNil.IsGPU() {
+		t.Error("Nil dispatcher should default to CPU")
+	}
+}
+
+func TestSwiGLU(t *testing.T) {
+	gate := cpu.NewTensor(tensor.Shape{2}, []float32{0, 1})
+	up := cpu.NewTensor(tensor.Shape{2}, []float32{2, 3})
+	out := cpu.NewTensor(tensor.Shape{2}, nil)
+
+	ctx := NewContext(nil, 0)
+	if err := SwiGLU(ctx, gate, up, out); err != nil {
+		t.Fatalf("SwiGLU failed: %v", err)
+	}
+
+	// SiLU(0) = 0, so out[0] = 0 * 2 = 0
+	// SiLU(1) ≈ 0.731, so out[1] ≈ 0.731 * 3 ≈ 2.19
+	data := out.DataFloat32()
+	if data[0] != 0 {
+		t.Errorf("out[0] = %f, expected 0", data[0])
+	}
+	if data[1] < 2.0 || data[1] > 2.5 {
+		t.Errorf("out[1] = %f, expected ~2.2", data[1])
+	}
+}

Some files are not shown because a large number of files changed in this diff.