#!/usr/bin/env python3
"""
Debug script to understand EXACTLY how HuggingFace NLLB generates translations.

We trace every step of the pipeline so it can be replicated in llama.cpp.
"""
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

def main():
    print("=== Loading NLLB Model ===")
    model_name = "facebook/nllb-200-distilled-600M"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    model.eval()  # inference mode: disables dropout so runs are reproducible

    # Test input
    text = "Hello"
    src_lang = "eng_Latn"
    tgt_lang = "fra_Latn"

    print("\n=== Input ===")
    print(f"Text: {text}")
    print(f"Source: {src_lang} -> Target: {tgt_lang}")

    # Step 1: Tokenize input
    tokenizer.src_lang = src_lang
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]

    print("\n=== Step 1: Tokenization ===")
    print(f"Input IDs: {input_ids.tolist()}")
    print(f"Input tokens: {[tokenizer.decode([t]) for t in input_ids[0]]}")
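
    # Worth verifying in the printout above: the NLLB tokenizer wraps the raw
    # pieces with a source-language code and a closing </s>. The exact position
    # of the language token has depended on the transformers version (there is
    # a legacy-behaviour flag), which is why the tokens are printed here rather
    # than assumed.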

    # Step 2: Encode
    print("\n=== Step 2: Encoder ===")
    with torch.no_grad():
        encoder_outputs = model.get_encoder()(input_ids)

    print(f"Encoder output shape: {encoder_outputs.last_hidden_state.shape}")
    print(f"Encoder output stats: mean={encoder_outputs.last_hidden_state.mean():.6f}, "
          f"std={encoder_outputs.last_hidden_state.std():.6f}")
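
    # The encoder runs exactly once; every decoder step below reuses these
    # hidden states as the cross-attention memory. A llama.cpp port should
    # likewise encode once and cache the result.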

    # Step 3: Prepare decoder input
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    print("\n=== Step 3: Decoder Initialization ===")
    print(f"Target language: {tgt_lang}")
    print(f"Target language ID: {tgt_lang_id}")
    print(f"BOS token ID: {model.config.bos_token_id}")
    print(f"EOS token ID: {model.config.eos_token_id}")
    print(f"Decoder start token ID: {model.config.decoder_start_token_id}")
    print(f"PAD token ID: {model.config.pad_token_id}")
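
    # The printout above should show decoder_start_token_id equal to the EOS
    # ID: NLLB primes its decoder with </s> rather than a dedicated BOS token,
    # and the target-language code is an ordinary entry in the shared
    # vocabulary, which is why convert_tokens_to_ids() suffices to look it up.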

    # Step 4: Manual decoding (without .generate()) to see what happens
    print("\n=== Step 4: Manual Greedy Decoding ===")

    # Start with decoder_start_token_id (which is EOS for NLLB) + target language
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id, tgt_lang_id]])
    print(f"Initial decoder input: {decoder_input_ids.tolist()}")
    print(f"Initial tokens: {[tokenizer.decode([t]) for t in decoder_input_ids[0]]}")
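
    # This two-token prefix mirrors what .generate() builds internally: it
    # starts from decoder_start_token_id and then forces the target-language
    # token as the first "generated" token via forced_bos_token_id.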

    max_length = 20
    generated_tokens = []

    for step in range(max_length):
        print(f"\n--- Step {step} ---")
        print(f"Decoder input shape: {decoder_input_ids.shape}")
        print(f"Decoder input IDs: {decoder_input_ids[0].tolist()}")

        with torch.no_grad():
            outputs = model(
                input_ids=None,  # the input is already encoded
                encoder_outputs=encoder_outputs,
                decoder_input_ids=decoder_input_ids,
                use_cache=False,  # disable KV cache for debugging
            )
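
        # With use_cache=False the whole decoder prefix is re-processed on
        # every step (quadratic in output length). Wasteful, but each step is
        # then a pure function of the visible tokens, which makes it easier to
        # compare against a llama.cpp implementation one step at a time.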

        # Get logits at the last position only: the prediction for the next token
        logits = outputs.logits[0, -1, :]
        print(f"Logits shape: {logits.shape}")
        print(f"Logits stats: mean={logits.mean():.6f}, std={logits.std():.6f}, "
              f"max={logits.max():.6f}")
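
        # These summary stats act as a cheap per-step fingerprint: comparing
        # mean/std/max against another implementation quickly localizes the
        # first step at which the two diverge.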

        # Get top-5 predictions
        top_k = 5
        top_logits, top_indices = torch.topk(logits, top_k)
        print(f"Top {top_k} predictions:")
        for i, (idx, logit) in enumerate(zip(top_indices, top_logits)):
            token = tokenizer.decode([idx.item()])
            print(f"  {i+1}. Token {idx.item()}: '{token}' (logit: {logit.item():.4f})")
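
        # Near-ties show up here: when the top two logits are within float
        # tolerance, another implementation can legitimately pick a different
        # token even if everything upstream matches closely.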

        # Greedy: take the argmax, reshaped to (1, 1) for concatenation below
        next_token = torch.argmax(logits).unsqueeze(0).unsqueeze(0)
        next_token_id = next_token.item()
        next_token_str = tokenizer.decode([next_token_id])

        print(f"Selected token: {next_token_id} ('{next_token_str}')")

        generated_tokens.append(next_token_id)

        # Stop once the model emits EOS
        if next_token_id == model.config.eos_token_id:
            print("EOS reached!")
            break

        # Append the chosen token and feed the extended prefix back in
        decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)

    # Decode the full output
    print("\n=== Final Result ===")
    print(f"Generated token IDs: {generated_tokens}")
    translation = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    print(f"Translation: {translation}")

    # Also test with .generate() for comparison
    print("\n=== Comparison with .generate() ===")
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    generated_ids = model.generate(
        inputs["input_ids"],
        forced_bos_token_id=forced_bos_token_id,
        max_length=20,
        num_beams=1,  # greedy decoding
        do_sample=False,
    )
    print(f"Generated IDs: {generated_ids[0].tolist()}")
    translation_auto = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print(f"Translation: {translation_auto}")
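
    # Sanity check: assuming .generate() emits the same two-token </s> +
    # language-code prefix before the translation, the manual loop should
    # match it token-for-token on short outputs. (generate()'s max_length
    # counts the prefix, so very long outputs may differ in length.)
    auto_tokens = generated_ids[0].tolist()[2:]
    print(f"Manual loop matches .generate(): {auto_tokens == generated_tokens}")
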
if __name__ == "__main__":
    main()