| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281 |
- import argparse
- import requests
- import json
- from pathlib import Path
- import logging
- logger = logging.getLogger("compare-logprobs")
- logging.basicConfig(level=logging.INFO)
- DESCRIPTION = """
- Compare logits between llama.cpp and another inference engine using OpenAI-compatible server endpoints.
- Unlike compare-logits.py, it allows dumping logits from a hosted API endpoint. Useful when it's not possible to run both models locally.
- Example usage:
- Step 1: Dump logits from two different servers
- python scripts/compare-logprobs.py dump logits_llama.log http://localhost:8080/v1/completions
- python scripts/compare-logprobs.py dump logits_other.log http://other-engine:8000/v1/completions
- (optionally, you can add --api-key <key> if the endpoint requires authentication)
- Step 2: Compare the dumped logits
- python scripts/compare-logprobs.py compare logits_llama.log logits_other.log report.md
- """
- def generate_input_prompt(length: int) -> list[str]:
- CORPUS = """
- You are an advanced AI assistant capable of using tools to gather information, perform calculations, or execute tasks. Always think step by step before responding. If a user's query requires external data, computation, or actions beyond your internal knowledge, use the appropriate tools via function calls.
- ### Tool Call Format:
- When you need to use a tool, output the call in this exact XML format. Include the opening and closing tags. Do not escape arguments; they will be parsed as plain text.
- You can make multiple calls in one go by placing them one after another.
- """
- words = [w.strip() for w in CORPUS.strip().split(" ")]
- words = [w for w in words if len(w) > 0] # filter out empty strings
- while len(words) < length:
- words += words
- return words[:length]
- def dump_logits(
- endpoint: str,
- output_path: Path,
- input_words: list[str],
- pattern: list[tuple[bool, int]],
- api_key=None,
- ):
- logger.info(f"Dumping logits to {output_path} from endpoint {endpoint}...")
- words = input_words
- curr_text = ""
- n_total = sum(n for get, n in pattern if get)
- n_done = 0
- i_cur = 0
- i_total = len(words)
- with output_path.open("w") as f:
- for get, n in pattern:
- if not get:
- # skip n words
- for i in range(n):
- curr_text += words.pop(0) + " "
- i_cur += 1
- continue
- # get n words
- for i in range(n):
- curr_text += words.pop(0) + " "
- payload = {
- "prompt": curr_text.strip(),
- "temperature": 0.0,
- "top_k": 1,
- "max_tokens": 1,
- "logprobs": 1,
- "stream": False,
- }
- response = requests.post(
- endpoint,
- json=payload,
- headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
- )
- response.raise_for_status()
- data = response.json()
- data["__index"] = i_cur # add index for easier debugging later
- data = json.dumps(data)
- f.write(f"{data}\n")
- n_done += 1
- i_cur += 1
- logger.info(
- f"\n\n{data}\n\n[Step: {n_done}/{n_total} | Word: {i_cur}/{i_total}]"
- )
- logger.info(f"Logits dumped to {output_path}")
- def get_token_logprobs(data: dict):
- logprobs = data["choices"][0]["logprobs"]
- if "content" in logprobs:
- # llama.cpp case
- top = logprobs["content"][0]["top_logprobs"][0]
- return top["token"], top["logprob"]
- else:
- # vllm case
- tokens = logprobs["tokens"]
- token_logprobs = logprobs["token_logprobs"]
- return tokens[0], token_logprobs[0]
- def clean_text(text: str) -> str:
- return (
- "'"
- + text.replace("\n", "\\n")
- .replace("\t", "\\t")
- .replace("\r", "\\r")
- .replace("|", "\\|")
- + "'"
- )
- def compare_logits(input1: Path, input2: Path, output_path: Path):
- with input1.open("r") as f1, input2.open("r") as f2, output_path.open("w") as fout:
- lines1 = f1.readlines()
- lines2 = f2.readlines()
- tab_header = [
- "idx",
- input1.name,
- "logprob_1",
- input2.name,
- "logprob_2",
- "diff (abs)",
- ]
- tab_entries = []
- tab_max_widths = [len(h) for h in tab_header]
- assert len(lines1) == len(
- lines2
- ), "Input files must have the same number of lines."
- fout.write("# Logits Comparison Report\n\n")
- for i, (line1, line2) in enumerate(zip(lines1, lines2)):
- if not line1.strip() or not line2.strip():
- continue # skip empty lines
- data1 = json.loads(line1)
- data2 = json.loads(line2)
- idx1 = data1.get("__index", -1)
- idx2 = data2.get("__index", -1)
- if idx1 != idx2:
- logger.warning(
- f"Warning: Mismatched indices at line {i}: {idx1} vs {idx2}"
- )
- token1, logprob1 = get_token_logprobs(data1)
- token2, logprob2 = get_token_logprobs(data2)
- token1 = clean_text(token1)
- token2 = clean_text(token2)
- abs_diff = abs(logprob1 - logprob2)
- tab_entries.append(
- (
- str(idx1 + 1),
- token1,
- f"{logprob1:.4f}",
- token2,
- f"{logprob2:.4f}",
- f"{(abs_diff):.4f}",
- )
- )
- for i in range(len(tab_entries)):
- for j in range(len(tab_header)):
- tab_max_widths[j] = max(tab_max_widths[j], len(tab_entries[i][j]))
- output = ""
- for j in range(len(tab_header)):
- output += f"| {tab_header[j]:<{tab_max_widths[j]}} "
- output += "|\n"
- for j in range(len(tab_header)):
- output += f"|{'-' * (tab_max_widths[j] + 2)}"
- output += "|\n"
- for entry in tab_entries:
- for j in range(len(tab_header)):
- output += f"| {entry[j]:<{tab_max_widths[j]}} "
- output += "|\n"
- logger.info("\n" + output)
- fout.write(output)
- logger.info(f"Report written to {output_path}")
- def parse_pattern(pattern: str) -> list[tuple[bool, int]]:
- parts = pattern.split(",")
- result = []
- for i, part in enumerate(parts):
- n = int(part)
- if i % 2 == 0:
- result.append((True, n)) # get n words
- else:
- result.append((False, n)) # skip n words
- return result
- def parse_args() -> argparse.Namespace:
- parser = argparse.ArgumentParser(
- description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter
- )
- subparsers = parser.add_subparsers(
- dest="verb", required=True, help="action to perform"
- )
- # dump subcommand
- parser_dump = subparsers.add_parser("dump", help="dump logits from an endpoint")
- parser_dump.add_argument(
- "output", type=Path, help="output path for dumped logits (.log)"
- )
- parser_dump.add_argument(
- "endpoint", type=str, help="OAI-compat /completions endpoint"
- )
- parser_dump.add_argument(
- "--api-key",
- type=str,
- default=None,
- help="API key for authentication (if required)",
- )
- parser_dump.add_argument(
- "--file",
- type=Path,
- default=None,
- help="File containing prompt to use instead of the default",
- )
- parser_dump.add_argument(
- "--pattern",
- type=str,
- default="10,1000,10,4000,10",
- help="Pattern n_get,n_skip,... where n_get is number of words to get and n_skip is number of words to skip (num of words, NOT num of tokens)",
- )
- # compare subcommand
- parser_compare = subparsers.add_parser(
- "compare", help="compare two dumped logits files"
- )
- parser_compare.add_argument("input1", type=Path, help="first input file (.log)")
- parser_compare.add_argument("input2", type=Path, help="second input file (.log)")
- parser_compare.add_argument(
- "output", type=Path, help="output path for comparison report (.md)"
- )
- try:
- return parser.parse_args()
- except Exception as e:
- parser.print_help()
- raise e
- def main():
- args = parse_args()
- if args.verb == "dump":
- pattern = parse_pattern(args.pattern)
- input_length = sum(n for _, n in pattern)
- input_words = generate_input_prompt(input_length)
- if args.file is not None:
- with args.file.open("r") as f:
- input_words = f.read().strip().split(" ")
- if input_length < sum(n for _, n in pattern):
- raise ValueError(
- f"Input file has only {input_length} words, but pattern requires at least {input_length} words."
- )
- input_length = len(input_words)
- logger.info(f"Using {input_length} words")
- dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key)
- elif args.verb == "compare":
- compare_logits(args.input1, args.input2, args.output)
- else:
- raise ValueError(f"Unknown verb: {args.verb}")
- if __name__ == "__main__":
- main()
|