llama.cpp/nllb_testing/debug_hf_nllb.py

#!/usr/bin/env python3
"""
Debug script to understand EXACTLY how HuggingFace NLLB generates translations.
We trace every step so the same pipeline can be replicated in llama.cpp.
"""
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np


def main():
    print("=== Loading NLLB Model ===")
    model_name = "facebook/nllb-200-distilled-600M"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    model.eval()

    # Test input
    text = "Hello"
    src_lang = "eng_Latn"
    tgt_lang = "fra_Latn"
    print("\n=== Input ===")
    print(f"Text: {text}")
    print(f"Source: {src_lang} -> Target: {tgt_lang}")

    # Step 1: Tokenize input
    tokenizer.src_lang = src_lang
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    print("\n=== Step 1: Tokenization ===")
    print(f"Input IDs: {input_ids.tolist()}")
    print(f"Input tokens: {[tokenizer.decode([t]) for t in input_ids[0]]}")

    # Step 2: Encode
    print("\n=== Step 2: Encoder ===")
    with torch.no_grad():
        encoder_outputs = model.get_encoder()(input_ids)
    print(f"Encoder output shape: {encoder_outputs.last_hidden_state.shape}")
    print(f"Encoder output stats: mean={encoder_outputs.last_hidden_state.mean():.6f}, "
          f"std={encoder_outputs.last_hidden_state.std():.6f}")
# Step 3: Prepare decoder input
tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
print(f"\n=== Step 3: Decoder Initialization ===")
print(f"Target language: {tgt_lang}")
print(f"Target language ID: {tgt_lang_id}")
print(f"BOS token ID: {model.config.bos_token_id}")
print(f"EOS token ID: {model.config.eos_token_id}")
print(f"Decoder start token ID: {model.config.decoder_start_token_id}")
print(f"PAD token ID: {model.config.pad_token_id}")

    # Step 4: Manual decoding (without generate) to see what happens
    print("\n=== Step 4: Manual Greedy Decoding ===")
    # Start with decoder_start_token_id (which is EOS for NLLB) + target language
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id, tgt_lang_id]])
    print(f"Initial decoder input: {decoder_input_ids.tolist()}")
    print(f"Initial tokens: {[tokenizer.decode([t]) for t in decoder_input_ids[0]]}")

    max_length = 20
    generated_tokens = []
    for step in range(max_length):
        print(f"\n--- Step {step} ---")
        print(f"Decoder input shape: {decoder_input_ids.shape}")
        print(f"Decoder input IDs: {decoder_input_ids[0].tolist()}")
        with torch.no_grad():
            outputs = model(
                input_ids=None,  # Input is already encoded
                encoder_outputs=encoder_outputs,
                decoder_input_ids=decoder_input_ids,
                use_cache=False,  # Disable KV cache for debugging
            )

        # Get logits for the last position
        logits = outputs.logits[0, -1, :]
        print(f"Logits shape: {logits.shape}")
        print(f"Logits stats: mean={logits.mean():.6f}, std={logits.std():.6f}, max={logits.max():.6f}")

        # Get top-5 predictions
        top_k = 5
        top_logits, top_indices = torch.topk(logits, top_k)
        print(f"Top {top_k} predictions:")
        for i, (idx, logit) in enumerate(zip(top_indices, top_logits)):
            token = tokenizer.decode([idx.item()])
            print(f"  {i + 1}. Token {idx.item()}: '{token}' (logit: {logit.item():.4f})")

        # Greedy: take the argmax
        next_token = torch.argmax(logits).unsqueeze(0).unsqueeze(0)
        next_token_id = next_token.item()
        next_token_str = tokenizer.decode([next_token_id])
        print(f"Selected token: {next_token_id} ('{next_token_str}')")
        generated_tokens.append(next_token_id)

        # Check for EOS
        if next_token_id == model.config.eos_token_id:
            print("EOS reached!")
            break

        # Append to decoder input
        decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)

    # Decode full output
    print("\n=== Final Result ===")
    print(f"Generated token IDs: {generated_tokens}")
    translation = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    print(f"Translation: {translation}")

    # Also test with .generate() for comparison
    print("\n=== Comparison with .generate() ===")
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    generated_ids = model.generate(
        inputs["input_ids"],
        forced_bos_token_id=forced_bos_token_id,
        max_length=20,
        num_beams=1,  # Greedy
        do_sample=False,
    )
    print(f"Generated IDs: {generated_ids[0].tolist()}")
    translation_auto = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print(f"Translation: {translation_auto}")


if __name__ == "__main__":
    main()