|
|
@@ -123,7 +123,7 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
|
|
|
print(f" sum = {t.sum().item():.6f}\n")
|
|
|
|
|
|
indexed_patterns = [ r"model\.layers\.[0-9]+_out", r"recurrent_cache_[0-9]+" ]
|
|
|
- non_indexed_patterns = [ r"k_pad", r"v_pad", r"q_pad" ]
|
|
|
+ non_indexed_patterns = [ r"k_pad", r"v_pad", r"q_scaled" ]
|
|
|
|
|
|
if any(re.fullmatch(p, name) for p in indexed_patterns):
|
|
|
if name not in token_counter:
|