diff --git a/scripts/compare-logprobs.py b/scripts/compare-logprobs.py index c7775a7917..ac10085b78 100644 --- a/scripts/compare-logprobs.py +++ b/scripts/compare-logprobs.py @@ -30,7 +30,7 @@ def get_remote_corpus(url: str, length: int) -> list[str]: response.raise_for_status() corpus = response.text words = [w.strip() for w in corpus.strip().split(" ")] - words = [w if "<" in w else "TEST" for w in words] # make sure nothing looks like a special token + words = [w for w in words if "<" not in w] # make sure nothing looks like special tokens words = [w for w in words if len(w) > 0] # filter out empty strings while len(words) < length: words += words