  1. #!/usr/bin/env python3
  2. import argparse
  3. import os
  4. import importlib
  5. from pathlib import Path
  6. from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
  7. import torch
  8. import numpy as np
  9. unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
  10. parser = argparse.ArgumentParser(description='Process model with specified path')
  11. parser.add_argument('--model-path', '-m', help='Path to the model')
  12. args = parser.parse_args()
  13. model_path = os.environ.get('MODEL_PATH', args.model_path)
  14. if model_path is None:
  15. parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
  16. config = AutoConfig.from_pretrained(model_path)
  17. print("Model type: ", config.model_type)
  18. print("Vocab size: ", config.vocab_size)
  19. print("Hidden size: ", config.hidden_size)
  20. print("Number of layers: ", config.num_hidden_layers)
  21. print("BOS token id: ", config.bos_token_id)
  22. print("EOS token id: ", config.eos_token_id)
  23. print("Loading model and tokenizer using AutoTokenizer:", model_path)
  24. tokenizer = AutoTokenizer.from_pretrained(model_path)
  25. config = AutoConfig.from_pretrained(model_path)
  26. if unreleased_model_name:
  27. model_name_lower = unreleased_model_name.lower()
  28. unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
  29. class_name = f"{unreleased_model_name}ForCausalLM"
  30. print(f"Importing unreleased model module: {unreleased_module_path}")
  31. try:
  32. model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
  33. model = model_class.from_pretrained(model_path) # Note: from_pretrained, not fromPretrained
  34. except (ImportError, AttributeError) as e:
  35. print(f"Failed to import or load model: {e}")
  36. exit(1)
  37. else:
  38. model = AutoModelForCausalLM.from_pretrained(model_path)
  39. model_name = os.path.basename(model_path)
  40. # Printing the Model class to allow for easier debugging. This can be useful
  41. # when working with models that have not been publicly released yet and this
  42. # migth require that the concrete class is imported and used directly instead
  43. # of using AutoModelForCausalLM.
  44. print(f"Model class: {model.__class__.__name__}")
  45. prompt = "Hello, my name is"
  46. input_ids = tokenizer(prompt, return_tensors="pt").input_ids
  47. print(f"Input tokens: {input_ids}")
  48. print(f"Input text: {repr(prompt)}")
  49. print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
  50. with torch.no_grad():
  51. outputs = model(input_ids)
  52. logits = outputs.logits
  53. # Extract logits for the last token (next token prediction)
  54. last_logits = logits[0, -1, :].cpu().numpy()
  55. print(f"Logits shape: {logits.shape}")
  56. print(f"Last token logits shape: {last_logits.shape}")
  57. print(f"Vocab size: {len(last_logits)}")
  58. data_dir = Path("data")
  59. data_dir.mkdir(exist_ok=True)
  60. bin_filename = data_dir / f"pytorch-{model_name}.bin"
  61. txt_filename = data_dir / f"pytorch-{model_name}.txt"
  62. # Save to file for comparison
  63. last_logits.astype(np.float32).tofile(bin_filename)
  64. # Also save as text file for easy inspection
  65. with open(txt_filename, "w") as f:
  66. for i, logit in enumerate(last_logits):
  67. f.write(f"{i}: {logit:.6f}\n")
  68. # Print some sample logits for quick verification
  69. print(f"First 10 logits: {last_logits[:10]}")
  70. print(f"Last 10 logits: {last_logits[-10:]}")
  71. # Show top 5 predicted tokens
  72. top_indices = np.argsort(last_logits)[-5:][::-1]
  73. print("Top 5 predictions:")
  74. for idx in top_indices:
  75. token = tokenizer.decode([idx])
  76. print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
  77. print(f"Saved bin logits to: {bin_filename}")
  78. print(f"Saved txt logist to: {txt_filename}")