scripts: update corpus of compare-logprobs (#19326)
* scripts: update corpus of compare-logprobs * fix
This commit is contained in:
parent
8fdf269dad
commit
c747294b2d
|
|
@ -25,16 +25,12 @@ Example usage:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def generate_input_prompt(length: int) -> list[str]:
|
def get_remote_corpus(url: str, length: int) -> list[str]:
|
||||||
CORPUS = """
|
response = requests.get(url)
|
||||||
You are an advanced AI assistant capable of using tools to gather information, perform calculations, or execute tasks. Always think step by step before responding. If a user's query requires external data, computation, or actions beyond your internal knowledge, use the appropriate tools via function calls.
|
response.raise_for_status()
|
||||||
|
corpus = response.text
|
||||||
### Tool Call Format:
|
words = [w.strip() for w in corpus.strip().split(" ")]
|
||||||
When you need to use a tool, output the call in this exact XML format. Include the opening and closing tags. Do not escape arguments; they will be parsed as plain text.
|
words = [w for w in words if "<" not in w] # make sure nothing looks like special tokens
|
||||||
|
|
||||||
You can make multiple calls in one go by placing them one after another.
|
|
||||||
"""
|
|
||||||
words = [w.strip() for w in CORPUS.strip().split(" ")]
|
|
||||||
words = [w for w in words if len(w) > 0] # filter out empty strings
|
words = [w for w in words if len(w) > 0] # filter out empty strings
|
||||||
while len(words) < length:
|
while len(words) < length:
|
||||||
words += words
|
words += words
|
||||||
|
|
@ -226,9 +222,9 @@ def parse_args() -> argparse.Namespace:
|
||||||
)
|
)
|
||||||
parser_dump.add_argument(
|
parser_dump.add_argument(
|
||||||
"--file",
|
"--file",
|
||||||
type=Path,
|
type=str,
|
||||||
default=None,
|
default="https://raw.githubusercontent.com/ggml-org/llama.cpp/eaba92c3dcc980ebe753348855d4a5d75c069997/tools/server/README.md",
|
||||||
help="File containing prompt to use instead of the default",
|
help="File containing prompt to use instead of the default (can also be an URL)",
|
||||||
)
|
)
|
||||||
parser_dump.add_argument(
|
parser_dump.add_argument(
|
||||||
"--pattern",
|
"--pattern",
|
||||||
|
|
@ -259,17 +255,19 @@ def main():
|
||||||
|
|
||||||
if args.verb == "dump":
|
if args.verb == "dump":
|
||||||
pattern = parse_pattern(args.pattern)
|
pattern = parse_pattern(args.pattern)
|
||||||
input_length = sum(n for _, n in pattern)
|
required_words = sum(n for _, n in pattern)
|
||||||
input_words = generate_input_prompt(input_length)
|
if args.file.startswith("http"):
|
||||||
if args.file is not None:
|
input_words = get_remote_corpus(args.file, required_words)
|
||||||
with args.file.open("r") as f:
|
logger.info(f"Fetched {len(input_words)} words from remote {args.file}")
|
||||||
|
else:
|
||||||
|
with open(args.file, "r") as f:
|
||||||
input_words = f.read().strip().split(" ")
|
input_words = f.read().strip().split(" ")
|
||||||
if input_length < sum(n for _, n in pattern):
|
input_words = [w for w in input_words if len(w) > 0] # filter out empty strings
|
||||||
|
if len(input_words) < required_words:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Input file has only {input_length} words, but pattern requires at least {input_length} words."
|
f"Input file has only {len(input_words)} words, but pattern requires at least {required_words} words."
|
||||||
)
|
)
|
||||||
input_length = len(input_words)
|
logger.info(f"Using {len(input_words)} words")
|
||||||
logger.info(f"Using {input_length} words")
|
|
||||||
dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key)
|
dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key)
|
||||||
elif args.verb == "compare":
|
elif args.verb == "compare":
|
||||||
compare_logits(args.input1, args.input2, args.output)
|
compare_logits(args.input1, args.input2, args.output)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue