Piotr Wilkin 3 месяцев назад
Родитель
Сommit
22ee5a971b

+ 135 - 0
CACHE_STATS_README.md

@@ -0,0 +1,135 @@
+# Cache Statistics Feature for llama.cpp
+
+This document describes the cache statistics functionality added to llama.cpp for debugging and analyzing the recurrent cache behavior in models like Qwen3 Next.
+
+## Overview
+
+The cache statistics feature allows users to dump detailed information about the model's cache state after each token generation. This is particularly useful for:
+
+- Understanding how the recurrent cache evolves during inference
+- Debugging cache-related issues in hybrid models (attention + recurrent)
+- Analyzing memory usage patterns
+- Comparing cache behavior between different models
+
+## Usage
+
+### Command Line Option
+
+Add the `--dump-cache` flag to any llama.cpp command to enable cache statistics printing:
+
+```bash
+./llama-cli -m your_model.gguf -p "Hello, my name is" -n 10 --dump-cache
+```
+
+### Test Script
+
+A convenient test script is provided:
+
+```bash
+./test_cache_stats.sh /path/to/model.gguf "Your prompt here"
+```
+
+## Output Format
+
+When enabled, the cache statistics are printed after each token generation:
+
+```
+=== CACHE STATISTICS FOR TOKEN 1 ===
+Model has 32 layers
+Memory address: 0x555555555555
+Sequence 0: pos_min=0, pos_max=5, length=6
+Memory supports shifting: true
+
+Layer-by-layer cache information:
+Note: Detailed tensor statistics require internal API access
+This framework shows where conv/state/recurrent cache data would be displayed
+
+Layer 0:
+  Conv State: [sum=N/A, mean=N/A] (shape=N/A)
+  Recurrent State: [sum=N/A, mean=N/A] (shape=N/A)
+  Key Cache: [sum=N/A, mean=N/A] (shape=N/A)
+  Value Cache: [sum=N/A, mean=N/A] (shape=N/A)
+
+...
+
+To access actual cache statistics, the following would be needed:
+1. Internal API access to llama_memory_hybrid::get_mem_recr()
+2. Access to llama_memory_recurrent::get_r_l() and ::get_s_l() tensors
+3. Access to llama_kv_cache tensors for attention layers
+4. ggml_tensor data access for sum/mean calculations
+=============================================
+```
+
+## Implementation Details
+
+### Files Modified
+
+1. **tools/main/main.cpp**: Added cache statistics printing functionality
+2. **common/common.h**: Added `dump_cache` parameter to `common_params` struct
+3. **common/arg.cpp**: Added `--dump-cache` command line argument parsing
+
+### Key Functions
+
+- `print_cache_statistics()`: Main function that prints cache information
+- Uses public llama.cpp APIs where available
+- Provides framework for accessing internal cache data
+
+### Limitations
+
+The current implementation provides a framework for cache statistics but has limitations due to the public API constraints:
+
+1. **Tensor Data Access**: Cannot directly access tensor data (sum, mean) without internal APIs
+2. **Layer Type Detection**: Cannot distinguish between attention and recurrent layers
+3. **Cache Type Identification**: Limited ability to determine specific cache types
+
+### Future Enhancements
+
+To fully implement cache statistics with actual tensor data, the following would be needed:
+
+1. **Internal API Access**: Friend class access or new public APIs for cache internals
+2. **Tensor Data Access**: Methods to access ggml_tensor data for calculations
+3. **Layer Type Information**: APIs to determine layer types (attention vs recurrent)
+4. **Cache Statistics Methods**: Built-in methods for cache statistics calculation
+
+## Comparison with Python Reference
+
+The Python reference implementation in `reference/tests/cache_stats_qwen3_next.py` provides full access to:
+
+- Convolution state tensors (conv_states)
+- Recurrent state tensors (recurrent_states)  
+- Key/value cache tensors
+- Actual sum and mean calculations
+
+The C++ implementation aims to provide similar functionality once the necessary internal APIs are available.
+
+## Troubleshooting
+
+### No Cache Statistics Visible
+
+If cache statistics don't appear:
+1. Ensure `--dump-cache` flag is used
+2. Check that the model supports cache operations
+3. Verify the model is loaded correctly
+
+### Memory Address Shows as Null
+
+This indicates no memory is allocated for the cache, which could mean:
+- Model doesn't support caching
+- Memory allocation failed
+- Incorrect model type
+
+## Development Notes
+
+For developers wanting to extend this functionality:
+
+1. **Internal Access**: The main limitation is accessing internal cache structures
+2. **API Design**: Consider adding public APIs for cache statistics
+3. **Performance**: Cache statistics printing should have minimal performance impact
+4. **Thread Safety**: Ensure thread safety when accessing cache data
+
+## Related Files
+
+- `reference/tests/cache_stats_qwen3_next.py`: Python reference implementation
+- `src/llama-memory-hybrid.h`: Hybrid memory structure definitions
+- `src/llama-memory-recurrent.h`: Recurrent memory structure definitions
+- `src/llama-kv-cache.h`: KV cache structure definitions

+ 7 - 0
common/arg.cpp

@@ -1655,6 +1655,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.kv_unified = true;
             params.kv_unified = true;
         }
         }
     ).set_env("LLAMA_ARG_KV_SPLIT"));
     ).set_env("LLAMA_ARG_KV_SPLIT"));
+    add_opt(common_arg(
+        {"--dump-cache"},
+        "dump cache statistics after each token generation",
+        [](common_params & params) {
+            params.dump_cache = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(common_arg(
     add_opt(common_arg(
         {"--no-context-shift"},
         {"--no-context-shift"},
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),

+ 2 - 0
common/common.h

@@ -397,6 +397,8 @@ struct common_params {
 
 
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
     ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
     ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
+    
+    bool dump_cache = false; // dump cache statistics after each token
 
 
     common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
     common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
 
 

+ 3 - 0
examples/model-conversion/qwen3stories.sh

@@ -0,0 +1,3 @@
+export MODEL_PATH=/devel/tools/llama.cpp/reference/theo77186_Qwen3-Next-70M-TinyStories
+export CONVERTED_MODEL=/devel/tools/llama.cpp/reference/theo77186_Qwen3-Next-70M-TinyStories/theo77186_Qwen3-Next-70M-TinyStories.gguf
+make causal-verify-logits

+ 7 - 1
examples/model-conversion/scripts/causal/run-converted-model.sh

@@ -4,6 +4,11 @@ set -e
 
 
 # First try command line argument, then environment variable, then file
 # First try command line argument, then environment variable, then file
 CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
 CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
+MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"
+
+if [ -z "$MODEL_TESTING_PROMPT"]; then
+    MODEL_TESTING_PROMPT="Hello, my name is"
+fi
 
 
 # Final check if we have a model path
 # Final check if we have a model path
 if [ -z "$CONVERTED_MODEL" ]; then
 if [ -z "$CONVERTED_MODEL" ]; then
@@ -14,7 +19,8 @@ if [ -z "$CONVERTED_MODEL" ]; then
 fi
 fi
 
 
 echo $CONVERTED_MODEL
 echo $CONVERTED_MODEL
+echo $MODEL_TESTING_PROMPT
 
 
 cmake --build ../../build --target llama-logits -j8
 cmake --build ../../build --target llama-logits -j8
 
 
-../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is"
+../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT"

+ 586 - 0
examples/model-conversion/scripts/causal/run-org-model-multi-token.py

@@ -0,0 +1,586 @@
+#!/usr/bin/env python3
+
+import argparse
+import os
+import importlib
+from pathlib import Path
+import re
+
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+import torch
+import numpy as np
+
+### If you want to dump RoPE activations, apply this monkey patch to the model
+### class from Transformers that you are running (replace apertus.modeling_apertus
+### with the proper package and class for your model
+### === START ROPE DEBUG ===
+# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
+
+# orig_rope = apply_rotary_pos_emb
+# torch.set_printoptions(threshold=float('inf'))
+# torch.set_printoptions(precision=6, sci_mode=False)
+
+# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+#     # log inputs
+#     summarize(q, "RoPE.q_in")
+#     summarize(k, "RoPE.k_in")
+
+#     # call original
+#     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+#     # log outputs
+#     summarize(q_out, "RoPE.q_out")
+#     summarize(k_out, "RoPE.k_out")
+
+#     return q_out, k_out
+
+# # Patch it
+# import transformers.models.apertus.modeling_apertus as apertus_mod  # noqa: E402
+# apertus_mod.apply_rotary_pos_emb = debug_rope
+### == END ROPE DEBUG ===
+
+token_counter = {}
+
+def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
+    global token, token_counter
+    """
+    Print a tensor in llama.cpp debug style.
+
+    Supports:
+    - 2D tensors (seq, hidden)
+    - 3D tensors (batch, seq, hidden)
+    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
+
+    Shows first and last max_vals of each vector per sequence position.
+    """
+    t = tensor.detach().to(torch.float32).cpu()
+
+    # Determine dimensions
+    if t.ndim == 3:
+        _, s, _ = t.shape
+    elif t.ndim == 2:
+        _, s = 1, t.shape[0]
+        t = t.unsqueeze(0)
+    elif t.ndim == 4:
+        _, s, _, _ = t.shape
+    else:
+        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
+        return
+
+    ten_shape = t.shape
+
+    print(f"ggml_debug: {name} = (f32)  ... = {{{ten_shape}}}")
+    print("                                     [")
+    print("                                      [")
+
+    # Determine indices for first and last sequences
+    first_indices = list(range(min(s, max_seq)))
+    last_indices = list(range(max(0, s - max_seq), s))
+
+    # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
+    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
+
+    # Combine indices
+    if has_overlap:
+        # If there's overlap, just use the combined unique indices
+        indices = sorted(list(set(first_indices + last_indices)))
+        separator_index = None
+    else:
+        # If no overlap, we'll add a separator between first and last sequences
+        indices = first_indices + last_indices
+        separator_index = len(first_indices)
+
+    for i, si in enumerate(indices):
+        # Add separator if needed
+        if separator_index is not None and i == separator_index:
+            print("                                       ...")
+
+        # Extract appropriate slice
+        vec = t[0, si]
+        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
+            flat = vec.flatten().tolist()
+        else:  # 2D or 3D case
+            flat = vec.tolist()
+
+        # First and last slices
+        first = flat[:max_vals]
+        last = flat[-max_vals:] if len(flat) >= 2 * max_vals else flat
+        first_str = ", ".join(f"{v:12.4f}" for v in first)
+        last_str = ", ".join(f"{v:12.4f}" for v in last)
+
+        if len(flat) >= 2 * max_vals:
+            print(f"                                       [{first_str}, ..., {last_str}]")
+        else:
+            print(f"                                       [{last_str}]")
+
+    print("                                      ],")
+    print("                                     ]")
+    print(f"                                     sum = {t.sum().item():.6f}\n")
+
+    pattern = r"model\.layers\.[0-9]+_out"
+    if re.fullmatch(pattern, name):
+        if name not in token_counter:
+            token_counter[name] = 1
+        else:
+            token_counter[name] = token_counter[name] + 1
+        save_tensor(t, f"reference/tensors/org/{name}_{token_counter[name]}.bin")
+
+from transformers.models.qwen3_next.modeling_qwen3_next import torch_causal_conv1d_update, apply_rotary_pos_emb  # noqa: E402
+orig_conv1d_update = torch_causal_conv1d_update
+orig_rope = apply_rotary_pos_emb
+import torch.nn.functional as F  # noqa: E402
+import typing  # noqa: E402
+
+def patched_torch_causal_conv1d_update(
+    hidden_states,
+    conv_state,
+    weight,
+    bias=None,
+    activation=None,
+):
+    _, hidden_size, seq_len = hidden_states.shape
+    state_len = conv_state.shape[-1]
+    summarize(hidden_states, "hidden_states_in")
+    summarize(conv_state, "conv_state_in")
+
+    hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
+    summarize(hidden_states_new, "hidden_states_new")
+    summarize(hidden_states_new[:, :, -state_len:], "hidden_states_to_copy")
+    summarize(conv_state, "conv_state_pre")
+    conv_state.copy_(hidden_states_new[:, :, -state_len:])
+    summarize(conv_state, "conv_state_post")
+    out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
+    summarize(out, "out")
+    summarize(out[:, :, -seq_len:], "out_proper")
+    out = F.silu(out[:, :, -seq_len:])
+    summarize(out, "out_silu")
+    out = out.to(hidden_states.dtype)
+    return out
+
+already_dumped_rope = False
+
+def save_tensor(tensor, filename):
+    """Save tensor to binary file with shape information."""
+    # Ensure tensors directory exists
+    os.makedirs(os.path.dirname(filename), exist_ok=True)
+    
+    # Convert to numpy and save
+    np_array = tensor.detach().cpu().numpy()
+    
+    # Save shape first (4 int64 values), then data
+    with open(filename, 'wb') as f:
+        shape = list(np_array.shape)
+        while len(shape) < 4:
+            shape.insert(0, 0)
+        
+        # Write shape as int64
+        shape_array = np.array(shape, dtype=np.int64)
+        f.write(shape_array.tobytes())
+        
+        # Write data as float32
+        np_array_float32 = np_array.astype(np.float32)
+        f.write(np_array_float32.tobytes())
+
+def patched_apply_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    global already_dumped_rope
+
+    # log inputs
+    summarize(q, "RoPE.q_in")
+    summarize(k, "RoPE.k_in")
+    summarize(cos, "cos")
+    summarize(sin, "sin")
+    if q.shape[1] == 2 and k.shape[1] == 1 and k.shape[2] == 1 and not already_dumped_rope:
+        already_dumped_rope = True
+        print("Dumping input tensors")
+        save_tensor(q, "reference/tensors/testrope_q_in.bin")
+        save_tensor(k, "reference/tensors/testrope_k_in.bin")
+        save_tensor(cos, "reference/tensors/testrope_cos_in.bin")
+        save_tensor(sin, "reference/tensors/testrope_sin_in.bin")
+
+    if position_ids:
+        summarize(position_ids, "position_ids")
+    print(f"Rotary dim is {cos.unsqueeze(unsqueeze_dim).shape[-1]}")
+
+    # call original
+    q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+    # log outputs
+    summarize(q_out, "RoPE.q_out")
+    summarize(k_out, "RoPE.k_out")
+
+    return q_out, k_out
+
+import transformers.models.qwen3_next.modeling_qwen3_next as qwen_mod  # noqa: E402
+qwen_mod.torch_causal_conv1d_update = patched_torch_causal_conv1d_update
+qwen_mod.apply_rotary_pos_emb = patched_apply_rope
+
+# Store original functions for patching
+original_functions = {}
+
+def debug_hook(name):
+    def fn(_m, input, output):
+        if isinstance(input, torch.Tensor):
+            summarize(input, name + "_in")
+        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
+            summarize(input[0], name + "_in")
+        if isinstance(output, torch.Tensor):
+            summarize(output, name + "_out")
+        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
+            summarize(output[0], name + "_out")
+
+    return fn
+
+def patch_all_forward_methods(model):
+    """Apply monkey patches to all forward methods in the model"""
+    for name, module in model.named_modules():
+        # Set layer index if applicable
+        parts = name.split('.')
+        module.layer_idx = -1 # Default invalid value
+
+        if len(parts) > 2 and parts[0] == 'model' and parts[1] == 'layers':
+            try:
+                module.layer_idx = int(parts[2])  # Convert to integer
+            except (ValueError, IndexError):
+                module.layer_idx = -1
+
+        # Apply forward hook to log all inputs/outputs
+        module.register_forward_hook(debug_hook(name))
+
+        # Additional patches for specific methods in various modules
+        if hasattr(module, 'forward'):
+            original_forward = module.forward
+            def make_patched_forward(orig_forward, mod_name):
+                def patched_forward(*args, **kwargs):
+                    # Log inputs
+                    for i, arg in enumerate(args):
+                        if isinstance(arg, torch.Tensor):
+                            summarize(arg, f"{mod_name}.forward.arg_{i}_in")
+
+                    # Call original forward
+                    result = orig_forward(*args, **kwargs)
+
+                    # Log output
+                    if isinstance(result, torch.Tensor):
+                        summarize(result, f"{mod_name}.forward.out")
+                    elif isinstance(result, (tuple, list)):
+                        for i, res in enumerate(result):
+                            if isinstance(res, torch.Tensor):
+                                summarize(res, f"{mod_name}.forward.out_{i}")
+
+                    return result
+                return patched_forward
+
+            module.forward = make_patched_forward(original_forward, name)
+
+def patch_silu():
+    """Patch torch.nn.functional.silu to log inputs and outputs"""
+    global original_functions
+
+    if 'silu' not in original_functions:
+        original_functions['silu'] = torch.nn.functional.silu
+
+    def patched_silu(input, inplace=False):
+        # Log input
+        summarize(input, "silu_in")
+
+        # Call original function
+        result = original_functions['silu'](input, inplace)
+
+        # Log output
+        summarize(result, "silu_out")
+
+        return result
+
+    # Replace the function in the torch.nn.functional module
+    torch.nn.functional.silu = patched_silu
+
+def patch_sigmoid():
+    """Patch torch.nn.functional.sigmoid to log inputs and outputs"""
+    global original_functions
+
+    if 'sigmoid' not in original_functions:
+        original_functions['sigmoid'] = torch.nn.functional.sigmoid
+
+    def patched_sigmoid(input):
+        # Log input
+        summarize(input, "sigmoid_in")
+
+        # Call original function
+        result = original_functions['sigmoid'](input)
+
+        # Log output
+        summarize(result, "sigmoid_out")
+
+        return result
+
+    # Replace the function in the torch.nn.functional module
+    torch.nn.functional.sigmoid = patched_sigmoid
+
+
+def patch_torch_sigmoid():
+    """Patch torch.nn.functional.sigmoid to log inputs and outputs"""
+    global original_functions
+
+    if 'torch_sigmoid' not in original_functions:
+        original_functions['torch_sigmoid'] = torch.sigmoid
+
+    def patched_torch_sigmoid(input):
+        # Log input
+        summarize(input, "torch_sigmoid_in")
+
+        # Call original function
+        result = original_functions['torch_sigmoid'](input)
+
+        # Log output
+        summarize(result, "torch_sigmoid_out")
+
+        return result
+
+    # Replace the function in the torch.nn.functional module
+    torch.sigmoid = patched_torch_sigmoid
+
+
+def patch_pad():
+    """Patch torch.nn.functional.pad to log inputs and outputs"""
+    global original_functions
+
+    if 'pad' not in original_functions:
+        original_functions['pad'] = torch.nn.functional.pad
+
+    def patched_pad(input: torch.Tensor, pad: typing.Sequence[int], mode: str = 'constant', value: float | None = None): # pyright: ignore[reportGeneralTypeIssues]
+        # Log input
+        summarize(input, "pad_in")
+        print(f"Padding shape is {pad}")
+
+        # Call original function
+        result = original_functions['pad'](input=input, pad=pad, mode=mode, value=value)
+
+        # Log output
+        summarize(result, "pad_out")
+
+        return result
+
+    # Replace the function in the torch.nn.functional module
+    torch.nn.functional.pad = patched_pad
+
+
+def save_kv_cache(past_key_values, step_num, data_dir, model_name):
+    """Save KV cache tensors for each layer"""
+    cache_dir = data_dir / f"kv_cache_step_{step_num}"
+    cache_dir.mkdir(exist_ok=True)
+    
+    # Access past_key_values if available
+    if past_key_values is not None:
+        for layer_idx, cache_tuple in enumerate(past_key_values):
+            if cache_tuple is None:
+                print(f"Cache tuple is None for layer {layer_idx} at step {step_num}")
+                continue
+                
+            # Handle different cache formats
+            if isinstance(cache_tuple, (tuple, list)) and len(cache_tuple) >= 2:
+                key, value = cache_tuple[0], cache_tuple[1]
+                
+                # Check if key and value are not None
+                if key is not None and value is not None:
+                    # Save key cache
+                    key_filename = cache_dir / f"layer_{layer_idx}_key.bin"
+                    key.detach().cpu().numpy().astype(np.float32).tofile(key_filename)
+                    
+                    # Save value cache
+                    value_filename = cache_dir / f"layer_{layer_idx}_value.bin"
+                    value.detach().cpu().numpy().astype(np.float32).tofile(value_filename)
+                    
+                    print(f"Saved KV cache for layer {layer_idx} at step {step_num}: key.shape={key.shape}, value.shape={value.shape}")
+                else:
+                    print(f"Key or value is None for layer {layer_idx} at step {step_num}")
+            else:
+                # Handle other cache formats (e.g., recurrent models)
+                print(f"Non-standard cache format for layer {layer_idx} at step {step_num}: {type(cache_tuple)}")
+                # Save as generic cache if it's a tensor
+                if hasattr(cache_tuple, 'detach'):
+                    cache_filename = cache_dir / f"layer_{layer_idx}_cache.bin"
+                    cache_tuple.detach().cpu().numpy().astype(np.float32).tofile(cache_filename)
+                    print(f"Saved generic cache for layer {layer_idx} at step {step_num}: shape={cache_tuple.shape}")
+    else:
+        print(f"No KV cache available at step {step_num}")
+
+
+unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
+
+parser = argparse.ArgumentParser(description="Process model with specified path")
+parser.add_argument("--model-path", "-m", help="Path to the model")
+parser.add_argument("--num-tokens", "-n", type=int, default=5, help="Number of tokens to generate")
+parser.add_argument("--prompt", "-p", default="Hello, my name is", help="Input prompt")
+parser.add_argument("--save-cache", action="store_true", help="Save KV cache at each step")
+args = parser.parse_args()
+
+model_path = os.environ.get("MODEL_PATH", args.model_path)
+if model_path is None:
+    parser.error(
+        "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
+    )
+
+config = AutoConfig.from_pretrained(model_path)
+
+print("Model type:       ", config.model_type)
+print("Vocab size:       ", config.vocab_size)
+print("Hidden size:      ", config.hidden_size)
+print("Number of layers: ", config.num_hidden_layers)
+print("BOS token id:     ", config.bos_token_id)
+print("EOS token id:     ", config.eos_token_id)
+
+print("Loading model and tokenizer using AutoTokenizer:", model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+config = AutoConfig.from_pretrained(model_path)
+
+if unreleased_model_name:
+    model_name_lower = unreleased_model_name.lower()
+    unreleased_module_path = (
+        f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
+    )
+    class_name = f"{unreleased_model_name}ForCausalLM"
+    print(f"Importing unreleased model module: {unreleased_module_path}")
+
+    try:
+        model_class = getattr(
+            importlib.import_module(unreleased_module_path), class_name
+        )
+        model = model_class.from_pretrained(
+            model_path
+        )  # Note: from_pretrained, not fromPretrained
+    except (ImportError, AttributeError) as e:
+        print(f"Failed to import or load model: {e}")
+        exit(1)
+else:
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path, device_map="auto", offload_folder="offload"
+    )
+
+patch_all_forward_methods(model)
+patch_silu()
+patch_pad()
+patch_sigmoid()
+patch_torch_sigmoid()
+
+model_name = os.path.basename(model_path)
+# Printing the Model class to allow for easier debugging. This can be useful
+# when working with models that have not been publicly released yet and this
+# migth require that the concrete class is imported and used directly instead
+# of using AutoModelForCausalLM.
+print(f"Model class: {model.__class__.__name__}")
+
+device = next(model.parameters()).device
+prompt = args.prompt
+input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+print(f"Input tokens: {input_ids}")
+print(f"Input text: {repr(prompt)}")
+print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
+
+data_dir = Path("data")
+data_dir.mkdir(exist_ok=True)
+
+# Store all generated tokens and logits
+all_generated_tokens = []
+all_logits = []
+
+with torch.no_grad():
+    # Initial forward pass
+    print(f"\n=== Initial Forward Pass ===")
+    outputs = model(input_ids, use_cache=True)
+    logits = outputs.logits
+    
+    # Extract logits for the last token (next token prediction)
+    last_logits = logits[0, -1, :].cpu().numpy()
+    all_logits.append(last_logits)
+    
+    print(f"Logits shape: {logits.shape}")
+    print(f"Last token logits shape: {last_logits.shape}")
+    
+    # Generate first token
+    next_token_id = np.argmax(last_logits).item()
+    all_generated_tokens.append(next_token_id)
+    
+    # Show top 5 predicted tokens for first step
+    top_indices = np.argsort(last_logits)[-5:][::-1]
+    print("Top 5 predictions for first token:")
+    for idx in top_indices:
+        token = tokenizer.decode([idx])
+        print(f"  Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
+    
+    print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")
+    
+    # Save KV cache if requested
+    if args.save_cache:
+        save_kv_cache(outputs.past_key_values, 0, data_dir, model_name)
+    
+    # Prepare for next iteration
+    past_key_values = outputs.past_key_values
+    current_input = torch.tensor([[next_token_id]], device=device)
+    
+    # Generate remaining tokens
+    for step in range(1, args.num_tokens):
+        print(f"\n=== Generation Step {step} ===")
+        
+        # Forward pass with cache
+        outputs = model(
+            input_ids=current_input, 
+            past_key_values=past_key_values,
+            use_cache=True
+        )
+        
+        logits = outputs.logits
+        last_logits = logits[0, -1, :].cpu().numpy()
+        all_logits.append(last_logits)
+        
+        # Generate next token
+        next_token_id = np.argmax(last_logits).item()
+        all_generated_tokens.append(next_token_id)
+        
+        # Show top 5 predicted tokens for this step
+        top_indices = np.argsort(last_logits)[-5:][::-1]
+        print(f"Top 5 predictions for step {step}:")
+        for idx in top_indices:
+            token = tokenizer.decode([idx])
+            print(f"  Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
+        
+        print(f"Generated token {next_token_id} ({repr(tokenizer.decode([next_token_id]))})")
+        
+        # Save KV cache if requested
+        if args.save_cache:
+            save_kv_cache(outputs.past_key_values, step, data_dir, model_name)
+        
+        # Update for next iteration
+        past_key_values = outputs.past_key_values
+        current_input = torch.tensor([[next_token_id]], device=device)
+
+# Save results
+bin_filename = data_dir / f"pytorch-{model_name}-multi-token.bin"
+txt_filename = data_dir / f"pytorch-{model_name}-multi-token.txt"
+
+# Save all logits concatenated
+all_logits_array = np.array(all_logits)
+all_logits_array.astype(np.float32).tofile(bin_filename)
+
+# Also save as text file for easy inspection
+with open(txt_filename, "w") as f:
+    f.write(f"Generated tokens: {all_generated_tokens}\n")
+    f.write(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}\n")
+    f.write(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}\n\n")
+    
+    for step, logits in enumerate(all_logits):
+        f.write(f"=== Step {step} logits ===\n")
+        for i, logit in enumerate(logits):
+            f.write(f"{i}: {logit:.6f}\n")
+        f.write("\n")
+
+print(f"\n=== Generation Complete ===")
+print(f"Generated {len(all_generated_tokens)} tokens: {all_generated_tokens}")
+print(f"Generated text: {repr(tokenizer.decode(all_generated_tokens))}")
+print(f"Full sequence: {repr(tokenizer.decode(input_ids[0].tolist() + all_generated_tokens))}")
+
+print(f"Saved bin logits to: {bin_filename}")
+print(f"Saved txt logits to: {txt_filename}")
+
+if args.save_cache:
+    print(f"KV cache saved to: {data_dir}/kv_cache_step_*")

+ 4 - 1
examples/model-conversion/scripts/causal/run-org-model.py

@@ -186,7 +186,10 @@ model_name = os.path.basename(model_path)
 print(f"Model class: {model.__class__.__name__}")
 print(f"Model class: {model.__class__.__name__}")
 
 
 device = next(model.parameters()).device
 device = next(model.parameters()).device
-prompt = "Hello, my name is"
+if os.getenv("MODEL_TESTING_PROMPT"):
+    prompt = os.getenv("MODEL_TESTING_PROMPT")
+else:
+    prompt = "Hello, my name is"
 input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
 input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
 
 
 print(f"Input tokens: {input_ids}")
 print(f"Input tokens: {input_ids}")

+ 1 - 0
pyrightconfig.json

@@ -6,6 +6,7 @@
   "reportDuplicateImport": "error",
   "reportDuplicateImport": "error",
   "reportDeprecated": "warning",
   "reportDeprecated": "warning",
   "reportUnnecessaryTypeIgnoreComment": "information",
   "reportUnnecessaryTypeIgnoreComment": "information",
+  "reportAttributeAccessIssue": "warning",
   "disableBytesTypePromotions": false, // TODO: change once Python 3.12 is the minimum
   "disableBytesTypePromotions": false, // TODO: change once Python 3.12 is the minimum
   "executionEnvironments": [
   "executionEnvironments": [
     {
     {

+ 1 - 1
src/llama-graph.cpp

@@ -1532,7 +1532,7 @@ ggml_tensor * llm_graph_context::build_attn(
 
 
     if (wo) {
     if (wo) {
         cur = build_lora_mm(wo, cur);
         cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_QWEN3NEXT) {
             // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
             // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
             ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
             ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
         }
         }

+ 6 - 2
src/models/llm_build_qwen3next.cpp

@@ -102,7 +102,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
         cur = build_layer_ffn(attn_post_norm, model, il);
         cur = build_layer_ffn(attn_post_norm, model, il);
         cb(cur, "ffn_out", il);
         cb(cur, "ffn_out", il);
         
         
-        // Residual connection for FFN - add to the tensor BEFORE post_attention_layernorm
+        // Residual connection for FFN - add to the tensor from before post_attention_layernorm
         cur = ggml_add(ctx0, cur, ffn_residual);
         cur = ggml_add(ctx0, cur, ffn_residual);
         cb(cur, "post_moe", il);
         cb(cur, "post_moe", il);
 
 
@@ -198,9 +198,13 @@ struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_t
     const float kq_scale =
     const float kq_scale =
         hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
         hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
     cur = build_attn(inp_attn, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
     cur = build_attn(inp_attn, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+    cb(cur, "attn_pregate", il);
+
+    struct ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
+    cb(gate_sigmoid, "gate_sigmoid", il);
 
 
     // Apply gating directly using the original gate tensor
     // Apply gating directly using the original gate tensor
-    cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
+    cur = ggml_mul(ctx0, cur, gate_sigmoid);
     cb(cur, "attn_gated", il);
     cb(cur, "attn_gated", il);
 
 
     cur = build_lora_mm(model.layers[il].wo, cur);
     cur = build_lora_mm(model.layers[il].wo, cur);

+ 37 - 0
test_cache_stats.sh

@@ -0,0 +1,37 @@
+#!/bin/bash
+
+# Test script for cache statistics functionality
+# This script demonstrates how to use the --dump-cache flag
+
+echo "Testing llama.cpp cache statistics functionality"
+echo "=============================================="
+
+# Check if a model path is provided
+if [ $# -eq 0 ]; then
+    echo "Usage: $0 <path_to_model.gguf> [prompt]"
+    echo "Example: $0 /path/to/qwen3-next.gguf \"Hello, my name is\""
+    exit 1
+fi
+
+MODEL_PATH="$1"
+PROMPT="${2:-Hello, my name is}"
+
+echo "Model: $MODEL_PATH"
+echo "Prompt: $PROMPT"
+echo ""
+
+# Run llama.cpp with cache statistics enabled
+echo "Running: ./llama-cli -m $MODEL_PATH -p \"$PROMPT\" -n 5 --dump-cache"
+echo ""
+
+# Build the command
+CMD="./build/bin/llama-cli -m $MODEL_PATH -p \"$PROMPT\" -n 5 --dump-cache"
+
+# Execute the command
+echo "Executing: $CMD"
+echo ""
+eval $CMD
+
+echo ""
+echo "Cache statistics test completed."
+echo "=============================================="

+ 181 - 0
tools/main/main.cpp

@@ -15,6 +15,10 @@
 #include <string>
 #include <string>
 #include <vector>
 #include <vector>
 
 
+// Forward declarations for internal cache access
+struct llama_memory_hybrid;
+struct llama_memory_recurrent;
+
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
 #include <signal.h>
 #include <signal.h>
 #include <unistd.h>
 #include <unistd.h>
@@ -40,6 +44,8 @@ static std::ostringstream       * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
 static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting  = false;
 static bool is_interacting  = false;
 static bool need_insert_eot = false;
 static bool need_insert_eot = false;
+static bool print_cache_stats = false;
+static int token_count = 0;
 
 
 static void print_usage(int argc, char ** argv) {
 static void print_usage(int argc, char ** argv) {
     (void) argc;
     (void) argc;
@@ -83,12 +89,181 @@ static void sigint_handler(int signo) {
 }
 }
 #endif
 #endif
 
 
+struct callback_data {
+    std::vector<uint8_t> data;
+    std::map<std::string, int32_t> tensors;
+};
+
+
+static std::string ggml_ne_string(const ggml_tensor * t) {
+    std::string str;
+    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+        str += std::to_string(t->ne[i]);
+        if (i + 1 < GGML_MAX_DIMS) {
+            str += ", ";
+        }
+    }
+    return str;
+}
+
+static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.i = (uint32_t)h.bits << 16;
+    return u.f;
+}
+
+static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
+    size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
+    float v;
+    if (type == GGML_TYPE_F16) {
+        v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
+    } else if (type == GGML_TYPE_F32) {
+        v = *(float *) &data[i];
+    } else if (type == GGML_TYPE_I64) {
+        v = (float) *(int64_t *) &data[i];
+    } else if (type == GGML_TYPE_I32) {
+        v = (float) *(int32_t *) &data[i];
+    } else if (type == GGML_TYPE_I16) {
+        v = (float) *(int16_t *) &data[i];
+    } else if (type == GGML_TYPE_I8) {
+        v = (float) *(int8_t *) &data[i];
+    } else if (type == GGML_TYPE_BF16) {
+        v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]);
+    } else {
+        GGML_ABORT("fatal error");
+    }
+    return v;
+}
+
+// Function to save a tensor to binary file
+static void save_tensor(struct ggml_tensor* tensor, const char* filename) {
+    FILE* f = fopen((std::string("reference/tensors/conv/") + std::string(filename)).c_str(), "wb");
+    if (!f) {
+        fprintf(stderr, "Failed to create file: %s\n", filename);
+        return;
+    }
+    
+    // Write shape
+    fwrite(tensor->ne, sizeof(int64_t), 4, f);
+    
+    // Calculate total elements
+    int64_t total_elements = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3];
+    
+    // Write data
+    fwrite(tensor->data, sizeof(float), total_elements, f);
+    fclose(f);
+}
+
+static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
+    GGML_ASSERT(n > 0);
+    float sum = 0;
+    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
+        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
+            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
+                for (int64_t i0 = 0; i0 < ne[0]; i0++) {
+                    const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
+                    sum += v;
+                }
+            }
+        }
+    }
+    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
+        LOG("                                     [\n");
+        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
+            if (i2 == n && ne[2] > 2*n) {
+                LOG("                                      ..., \n");
+                i2 = ne[2] - n;
+            }
+            LOG("                                      [\n");
+            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
+                if (i1 == n && ne[1] > 2*n) {
+                    LOG("                                       ..., \n");
+                    i1 = ne[1] - n;
+                }
+                LOG("                                       [");
+                for (int64_t i0 = 0; i0 < ne[0]; i0++) {
+                    if (i0 == n && ne[0] > 2*n) {
+                        LOG("..., ");
+                        i0 = ne[0] - n;
+                    }
+                    const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
+                    LOG("%12.4f", v);
+                    if (i0 < ne[0] - 1) LOG(", ");
+                }
+                LOG("],\n");
+            }
+            LOG("                                      ],\n");
+        }
+        LOG("                                     ]\n");
+        LOG("                                     sum = %f\n", sum);
+    }
+
+    // TODO: make this abort configurable/optional?
+    if (std::isnan(sum)) {
+        LOG_ERR("encountered NaN - aborting\n");
+        exit(0);
+    }
+}
+
+static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
+    auto * cb_data = (callback_data *) user_data;
+
+    const struct ggml_tensor * src0 = t->src[0];
+    const struct ggml_tensor * src1 = t->src[1];
+
+    if (ask) {
+        return true; // Always retrieve data
+    }
+
+    char src1_str[128] = {0};
+    if (src1) {
+        snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str());
+    }
+
+    LOG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__,
+         t->name, ggml_type_name(t->type), ggml_op_desc(t),
+         src0->name, ggml_ne_string(src0).c_str(),
+         src1 ? src1_str : "",
+         ggml_ne_string(t).c_str());
+
+
+    // copy the data from the GPU memory if needed
+    const bool is_host = ggml_backend_buffer_is_host(t->buffer);
+
+    if (!is_host) {
+        auto n_bytes = ggml_nbytes(t);
+        cb_data->data.resize(n_bytes);
+        ggml_backend_tensor_get(t, cb_data->data.data(), 0, n_bytes);
+    }
+
+    if (!ggml_is_quantized(t->type)) {
+        uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
+        ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
+        if (std::string(t->name).substr(0, std::string("post_moe-").size()) == "post_moe-") {
+            if (cb_data->tensors.count(t->name) == 0) {
+                cb_data->tensors[t->name] = 1;
+            } else {
+                cb_data->tensors[t->name]++;
+            }
+            save_tensor(t, (std::string(t->name) + "_" + std::to_string(cb_data->tensors[t->name]) + ".bin").c_str());
+        }
+    }
+
+    return true;
+}
+
 int main(int argc, char ** argv) {
 int main(int argc, char ** argv) {
     common_params params;
     common_params params;
     g_params = &params;
     g_params = &params;
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MAIN, print_usage)) {
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MAIN, print_usage)) {
         return 1;
         return 1;
     }
     }
+    
+    // Check if cache statistics printing is enabled
+    print_cache_stats = params.dump_cache;
 
 
     common_init();
     common_init();
 
 
@@ -136,6 +311,9 @@ int main(int argc, char ** argv) {
     std::vector<common_chat_msg> chat_msgs;
     std::vector<common_chat_msg> chat_msgs;
 
 
     // load the model and apply lora adapter, if any
     // load the model and apply lora adapter, if any
+    callback_data cb_data;
+    params.cb_eval = ggml_debug;
+    params.cb_eval_user_data = &cb_data;
     LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
     LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
     common_init_result llama_init = common_init_from_params(params);
     common_init_result llama_init = common_init_from_params(params);
 
 
@@ -706,6 +884,9 @@ int main(int argc, char ** argv) {
             // LOG_DBG("last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str());
             // LOG_DBG("last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str());
 
 
             embd.push_back(id);
             embd.push_back(id);
+            
+            // Print cache statistics after each token generation
+            token_count++;
 
 
             // echo this to console
             // echo this to console
             input_echo = true;
             input_echo = true;