model-conversion : cast logits to float32
This commit is contained in:
parent
5266379bca
commit
292f8e231c
|
|
@ -200,7 +200,7 @@ with torch.no_grad():
|
||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|
||||||
# Extract logits for the last token (next token prediction)
|
# Extract logits for the last token (next token prediction)
|
||||||
last_logits = logits[0, -1, :].cpu().numpy()
|
last_logits = logits[0, -1, :].float().cpu().numpy()
|
||||||
|
|
||||||
print(f"Logits shape: {logits.shape}")
|
print(f"Logits shape: {logits.shape}")
|
||||||
print(f"Last token logits shape: {last_logits.shape}")
|
print(f"Last token logits shape: {last_logits.shape}")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue