Просмотр исходного кода

model-conversion : cast logits to float32 (#18009)

Georgi Gerganov 1 месяц назад
Родитель
Сommit
77ad8542bd
1 измененных файлов с 1 добавлено и 1 удалено
  1. 1 1
      examples/model-conversion/scripts/causal/run-org-model.py

+ 1 - 1
examples/model-conversion/scripts/causal/run-org-model.py

@@ -200,7 +200,7 @@ with torch.no_grad():
     logits = outputs.logits
 
     # Extract logits for the last token (next token prediction)
-    last_logits = logits[0, -1, :].cpu().numpy()
+    last_logits = logits[0, -1, :].float().cpu().numpy()
 
     print(f"Logits shape: {logits.shape}")
     print(f"Last token logits shape: {last_logits.shape}")