diff --git a/scripts/compare-logprobs.py b/scripts/compare-logprobs.py index 63861dd9a4..ac10085b78 100644 --- a/scripts/compare-logprobs.py +++ b/scripts/compare-logprobs.py @@ -25,16 +25,12 @@ Example usage: """ -def generate_input_prompt(length: int) -> list[str]: - CORPUS = """ - You are an advanced AI assistant capable of using tools to gather information, perform calculations, or execute tasks. Always think step by step before responding. If a user's query requires external data, computation, or actions beyond your internal knowledge, use the appropriate tools via function calls. - - ### Tool Call Format: - When you need to use a tool, output the call in this exact XML format. Include the opening and closing tags. Do not escape arguments; they will be parsed as plain text. - - You can make multiple calls in one go by placing them one after another. - """ - words = [w.strip() for w in CORPUS.strip().split(" ")] +def get_remote_corpus(url: str, length: int) -> list[str]: + response = requests.get(url) + response.raise_for_status() + corpus = response.text + words = [w.strip() for w in corpus.strip().split(" ")] + words = [w for w in words if "<" not in w] # make sure nothing looks like special tokens words = [w for w in words if len(w) > 0] # filter out empty strings while len(words) < length: words += words @@ -226,9 +222,9 @@ def parse_args() -> argparse.Namespace: ) parser_dump.add_argument( "--file", - type=Path, - default=None, - help="File containing prompt to use instead of the default", + type=str, + default="https://raw.githubusercontent.com/ggml-org/llama.cpp/eaba92c3dcc980ebe753348855d4a5d75c069997/tools/server/README.md", + help="File containing prompt to use instead of the default (can also be an URL)", ) parser_dump.add_argument( "--pattern", @@ -259,17 +255,19 @@ def main(): if args.verb == "dump": pattern = parse_pattern(args.pattern) - input_length = sum(n for _, n in pattern) - input_words = generate_input_prompt(input_length) - if args.file is not None: - with args.file.open("r") as f: + required_words = sum(n for _, n in pattern) + if args.file.startswith("http"): + input_words = get_remote_corpus(args.file, required_words) + logger.info(f"Fetched {len(input_words)} words from remote {args.file}") + else: + with open(args.file, "r") as f: input_words = f.read().strip().split(" ") - if input_length < sum(n for _, n in pattern): + input_words = [w for w in input_words if len(w) > 0] # filter out empty strings + if len(input_words) < required_words: raise ValueError( - f"Input file has only {input_length} words, but pattern requires at least {input_length} words." + f"Input file has only {len(input_words)} words, but pattern requires at least {required_words} words." ) - input_length = len(input_words) - logger.info(f"Using {input_length} words") + logger.info(f"Using {len(input_words)} words") dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key) elif args.verb == "compare": compare_logits(args.input1, args.input2, args.output)