diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py index 7d2b80057c..da1132c003 100755 --- a/examples/model-conversion/scripts/causal/run-org-model.py +++ b/examples/model-conversion/scripts/causal/run-org-model.py @@ -200,7 +200,7 @@ with torch.no_grad(): logits = outputs.logits # Extract logits for the last token (next token prediction) - last_logits = logits[0, -1, :].cpu().numpy() + last_logits = logits[0, -1, :].float().cpu().numpy() print(f"Logits shape: {logits.shape}") print(f"Last token logits shape: {last_logits.shape}")