import argparse
import json
import logging
from pathlib import Path

import requests

logger = logging.getLogger("compare-logprobs")
logging.basicConfig(level=logging.INFO)

DESCRIPTION = """
Compare logits between llama.cpp and another inference engine using OpenAI-compatible server endpoints.

Unlike compare-logits.py, this script dumps logits from a hosted API endpoint, which is useful when it is not possible to run both models locally.

Example usage:

Step 1: Dump logits from two different servers

    python scripts/compare-logprobs.py dump logits_llama.log http://localhost:8080/v1/completions
    python scripts/compare-logprobs.py dump logits_other.log http://other-engine:8000/v1/completions

    (optionally, add --api-key <key> if the endpoint requires authentication)

Step 2: Compare the dumped logits

    python scripts/compare-logprobs.py compare logits_llama.log logits_other.log report.md
"""


def generate_input_prompt(length: int) -> list[str]:
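    """Build a synthetic prompt of `length` words by repeating a fixed corpus.

    The corpus text is split on spaces and repeated until it contains at least
    `length` words, then truncated to exactly `length` words.
    """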
    CORPUS = """
You are an advanced AI assistant capable of using tools to gather information, perform calculations, or execute tasks. Always think step by step before responding. If a user's query requires external data, computation, or actions beyond your internal knowledge, use the appropriate tools via function calls.
### Tool Call Format:
When you need to use a tool, output the call in this exact XML format. Include the opening and closing tags. Do not escape arguments; they will be parsed as plain text.
You can make multiple calls in one go by placing them one after another.
"""
    words = [w.strip() for w in CORPUS.strip().split(" ")]
    words = [w for w in words if len(w) > 0]  # filter out empty strings
    while len(words) < length:
        words += words
    return words[:length]


def dump_logits(
    endpoint: str,
    output_path: Path,
    input_words: list[str],
    pattern: list[tuple[bool, int]],
    api_key=None,
):
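    """Query `endpoint` word by word and dump the returned logprobs as JSON lines.

    The word list is consumed according to `pattern`: for a (True, n) entry the
    prompt is extended one word at a time and a completion request is sent after
    each word; for a (False, n) entry the n words are appended to the prompt
    without sending any request. Each response is written to `output_path` as one
    JSON line, annotated with the 0-based word index under the "__index" key.

    Note: the payload targets an OpenAI-compatible /completions endpoint with
    max_tokens=1 and logprobs=1; "top_k" is not part of the standard OpenAI
    schema and may be ignored or rejected by some servers.
    """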
    logger.info(f"Dumping logits to {output_path} from endpoint {endpoint}...")
    words = input_words
    curr_text = ""
    n_total = sum(n for get, n in pattern if get)
    n_done = 0
    i_cur = 0
    i_total = len(words)
    with output_path.open("w") as f:
        for get, n in pattern:
            if not get:
                # skip n words
                for _ in range(n):
                    curr_text += words.pop(0) + " "
                    i_cur += 1
                continue
            # get n words
            for _ in range(n):
                curr_text += words.pop(0) + " "
                payload = {
                    "prompt": curr_text.strip(),
                    "temperature": 0.0,
                    "top_k": 1,
                    "max_tokens": 1,
                    "logprobs": 1,
                    "stream": False,
                }
                response = requests.post(
                    endpoint,
                    json=payload,
                    headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
                )
                response.raise_for_status()
                data = response.json()
                data["__index"] = i_cur  # add index for easier debugging later
                data = json.dumps(data)
                f.write(f"{data}\n")
                n_done += 1
                i_cur += 1
                logger.info(
                    f"\n\n{data}\n\n[Step: {n_done}/{n_total} | Word: {i_cur}/{i_total}]"
                )
    logger.info(f"Logits dumped to {output_path}")


def get_token_logprobs(data: dict):
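    """Extract the (token, logprob) pair of the top-1 token from a /completions response.

    Handles two response layouts: the llama.cpp server format, where logprobs are
    nested under choices[0].logprobs.content[0].top_logprobs, and the legacy
    OpenAI/vLLM format, which uses parallel "tokens" and "token_logprobs" arrays.
    """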
    logprobs = data["choices"][0]["logprobs"]
    if "content" in logprobs:
        # llama.cpp case
        top = logprobs["content"][0]["top_logprobs"][0]
        return top["token"], top["logprob"]
    else:
        # vllm case
        tokens = logprobs["tokens"]
        token_logprobs = logprobs["token_logprobs"]
        return tokens[0], token_logprobs[0]


def clean_text(text: str) -> str:
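    """Escape newlines, tabs, carriage returns, and pipes so the token renders safely in a Markdown table cell."""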
    return (
        "'"
        + text.replace("\n", "\\n")
        .replace("\t", "\\t")
        .replace("\r", "\\r")
        .replace("|", "\\|")
        + "'"
    )


def compare_logits(input1: Path, input2: Path, output_path: Path):
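    """Compare two JSONL dumps produced by dump_logits and write a Markdown report.

    Both files must contain the same number of lines. For each pair of entries the
    top token and its logprob are extracted from each dump, and the absolute
    difference of the logprobs is reported in an aligned Markdown table.
    """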
    with input1.open("r") as f1, input2.open("r") as f2, output_path.open("w") as fout:
        lines1 = f1.readlines()
        lines2 = f2.readlines()
        tab_header = [
            "idx",
            input1.name,
            "logprob_1",
            input2.name,
            "logprob_2",
            "diff (abs)",
        ]
        tab_entries = []
        tab_max_widths = [len(h) for h in tab_header]
        assert len(lines1) == len(
            lines2
        ), "Input files must have the same number of lines."
        fout.write("# Logits Comparison Report\n\n")
        for i, (line1, line2) in enumerate(zip(lines1, lines2)):
            if not line1.strip() or not line2.strip():
                continue  # skip empty lines
            data1 = json.loads(line1)
            data2 = json.loads(line2)
            idx1 = data1.get("__index", -1)
            idx2 = data2.get("__index", -1)
            if idx1 != idx2:
                logger.warning(
                    f"Warning: Mismatched indices at line {i}: {idx1} vs {idx2}"
                )
            token1, logprob1 = get_token_logprobs(data1)
            token2, logprob2 = get_token_logprobs(data2)
            token1 = clean_text(token1)
            token2 = clean_text(token2)
            abs_diff = abs(logprob1 - logprob2)
            tab_entries.append(
                (
                    str(idx1 + 1),
                    token1,
                    f"{logprob1:.4f}",
                    token2,
                    f"{logprob2:.4f}",
                    f"{abs_diff:.4f}",
                )
            )
        for i in range(len(tab_entries)):
            for j in range(len(tab_header)):
                tab_max_widths[j] = max(tab_max_widths[j], len(tab_entries[i][j]))
        output = ""
        for j in range(len(tab_header)):
            output += f"| {tab_header[j]:<{tab_max_widths[j]}} "
        output += "|\n"
        for j in range(len(tab_header)):
            output += f"|{'-' * (tab_max_widths[j] + 2)}"
        output += "|\n"
        for entry in tab_entries:
            for j in range(len(tab_header)):
                output += f"| {entry[j]:<{tab_max_widths[j]}} "
            output += "|\n"
        logger.info("\n" + output)
        fout.write(output)
    logger.info(f"Report written to {output_path}")


def parse_pattern(pattern: str) -> list[tuple[bool, int]]:
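    """Parse a comma-separated get/skip pattern into (get, count) tuples.

    Counts at even positions are words to get (request logprobs for) and counts
    at odd positions are words to skip, e.g.
    "10,1000,10" -> [(True, 10), (False, 1000), (True, 10)].
    """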
    parts = pattern.split(",")
    result = []
    for i, part in enumerate(parts):
        n = int(part)
        if i % 2 == 0:
            result.append((True, n))  # get n words
        else:
            result.append((False, n))  # skip n words
    return result


def parse_args() -> argparse.Namespace:
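    """Parse command-line arguments for the "dump" and "compare" subcommands."""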
    parser = argparse.ArgumentParser(
        description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter
    )
    subparsers = parser.add_subparsers(
        dest="verb", required=True, help="action to perform"
    )

    # dump subcommand
    parser_dump = subparsers.add_parser("dump", help="dump logits from an endpoint")
    parser_dump.add_argument(
        "output", type=Path, help="output path for dumped logits (.log)"
    )
    parser_dump.add_argument(
        "endpoint", type=str, help="OAI-compat /completions endpoint"
    )
    parser_dump.add_argument(
        "--api-key",
        type=str,
        default=None,
        help="API key for authentication (if required)",
    )
    parser_dump.add_argument(
        "--file",
        type=Path,
        default=None,
        help="File containing prompt to use instead of the default",
    )
    parser_dump.add_argument(
        "--pattern",
        type=str,
        default="10,1000,10,4000,10",
        help="Pattern n_get,n_skip,... where n_get is the number of words to get and n_skip is the number of words to skip (counted in words, NOT tokens)",
    )

    # compare subcommand
    parser_compare = subparsers.add_parser(
        "compare", help="compare two dumped logits files"
    )
    parser_compare.add_argument("input1", type=Path, help="first input file (.log)")
    parser_compare.add_argument("input2", type=Path, help="second input file (.log)")
    parser_compare.add_argument(
        "output", type=Path, help="output path for comparison report (.md)"
    )

    try:
        return parser.parse_args()
    except Exception as e:
        parser.print_help()
        raise e


def main():
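    """Entry point: dispatch to dump_logits or compare_logits based on the subcommand."""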
    args = parse_args()
    if args.verb == "dump":
        pattern = parse_pattern(args.pattern)
        input_length = sum(n for _, n in pattern)
        input_words = generate_input_prompt(input_length)
        if args.file is not None:
            with args.file.open("r") as f:
                input_words = f.read().strip().split(" ")
            if len(input_words) < input_length:
                raise ValueError(
                    f"Input file has only {len(input_words)} words, but the pattern requires at least {input_length} words."
                )
            input_length = len(input_words)
        logger.info(f"Using {input_length} words")
        dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key)
    elif args.verb == "compare":
        compare_logits(args.input1, args.input2, args.output)
    else:
        raise ValueError(f"Unknown verb: {args.verb}")


if __name__ == "__main__":
    main()