|
|
@@ -50,10 +50,14 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
|
|
|
- 2D tensors (seq, hidden)
|
|
|
- 3D tensors (batch, seq, hidden)
|
|
|
- 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
|
|
|
+ - 5D tensors whose leading dimension has size 1 (squeezed down to 4D)
|
|
|
|
|
|
Shows first and last max_vals of each vector per sequence position.
|
|
|
"""
|
|
|
t = tensor.detach().to(torch.float32).cpu()
|
|
|
+ ten_shape = t.shape
|
|
|
+ while t.ndim > 4:
|
|
|
+ t = t.squeeze(0)
|
|
|
|
|
|
# Determine dimensions
|
|
|
if t.ndim == 3:
|
|
|
@@ -63,12 +67,11 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
|
|
|
t = t.unsqueeze(0)
|
|
|
elif t.ndim == 4:
|
|
|
_, s, _, _ = t.shape
|
|
|
+
|
|
|
else:
|
|
|
print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
|
|
|
return
|
|
|
|
|
|
- ten_shape = t.shape
|
|
|
-
|
|
|
print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
|
|
|
print(" [")
|
|
|
print(" [")
|