run-casual-gen-embeddings-org.py

#!/usr/bin/env python3

import argparse
import os
import importlib
import torch
import numpy as np
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from pathlib import Path
unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')

parser = argparse.ArgumentParser(description='Process model with specified path')
parser.add_argument('--model-path', '-m', help='Path to the model')
args = parser.parse_args()

model_path = os.environ.get('MODEL_PATH', args.model_path)
if model_path is None:
    parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
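# Usage sketch (the paths below are hypothetical examples):
#   ./run-casual-gen-embeddings-org.py --model-path models/my-model
#   MODEL_PATH=models/my-model ./run-casual-gen-embeddings-org.py
# Note: when both are given, MODEL_PATH takes precedence over --model-path.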
config = AutoConfig.from_pretrained(model_path)

print("Model type: ", config.model_type)
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)

print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
if unreleased_model_name:
    model_name_lower = unreleased_model_name.lower()
    unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
    class_name = f"{unreleased_model_name}ForCausalLM"
    print(f"Importing unreleased model module: {unreleased_module_path}")

    try:
        model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
        model = model_class.from_pretrained(model_path)
    except (ImportError, AttributeError) as e:
        print(f"Failed to import or load model: {e}")
        print("Falling back to AutoModelForCausalLM")
        model = AutoModelForCausalLM.from_pretrained(model_path)
else:
    model = AutoModelForCausalLM.from_pretrained(model_path)
print(f"Model class: {type(model)}")
#print(f"Model file: {type(model).__module__}")

model_name = os.path.basename(model_path)
print(f"Model name: {model_name}")
prompt = "Hello world today"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
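# Optional round-trip check (left commented out; exact round-tripping depends
# on the tokenizer's handling of special tokens and whitespace):
#   print(f"Decoded: {tokenizer.decode(input_ids[0], skip_special_tokens=True)!r}")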
with torch.no_grad():
    outputs = model(input_ids, output_hidden_states=True)

# Extract hidden states from the last layer.
# outputs.hidden_states is a tuple of (num_layers + 1) tensors;
# index -1 gets the last layer, shape: [batch_size, seq_len, hidden_size].
last_hidden_states = outputs.hidden_states[-1]

# Get embeddings for all tokens
token_embeddings = last_hidden_states[0].cpu().numpy()  # Remove batch dimension

print(f"Hidden states shape: {last_hidden_states.shape}")
print(f"Token embeddings shape: {token_embeddings.shape}")
print(f"Hidden dimension: {token_embeddings.shape[-1]}")
print(f"Number of tokens: {token_embeddings.shape[0]}")
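# Sanity-check sketch: for standard Hugging Face causal LMs, hidden_states holds
# the embedding-layer output plus one tensor per transformer layer, and the
# per-token embeddings should be [seq_len, hidden_size].
assert len(outputs.hidden_states) == config.num_hidden_layers + 1
assert token_embeddings.shape == (input_ids.shape[1], config.hidden_size)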
# Save raw token embeddings
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"

# Save all token embeddings as binary
print(token_embeddings)
token_embeddings.astype(np.float32).tofile(bin_filename)
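# Read-back sketch: tofile() writes a flat float32 blob with no header, so a
# consumer has to reshape it with the hidden size to recover the
# [num_tokens, hidden_size] layout.
loaded = np.fromfile(bin_filename, dtype=np.float32).reshape(-1, token_embeddings.shape[-1])
assert np.allclose(loaded, token_embeddings.astype(np.float32))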
# Save as text for inspection
with open(txt_filename, "w") as f:
    for i, embedding in enumerate(token_embeddings):
        for j, val in enumerate(embedding):
            f.write(f"{i} {j} {val:.6f}\n")
# Print embeddings per token in the requested format
print("\nToken embeddings:")
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
for i, embedding in enumerate(token_embeddings):
    if len(embedding) > 10:
        # Show the first 3 and last 3 values with "..." in between
        first_vals = " ".join(f"{val:8.6f}" for val in embedding[:3])
        last_vals = " ".join(f"{val:8.6f}" for val in embedding[-3:])
        print(f"embedding {i}: {first_vals} ... {last_vals}")
    else:
        # If the embedding is short, show all values
        vals = " ".join(f"{val:8.6f}" for val in embedding)
        print(f"embedding {i}: {vals}")
# Also show token info for reference
print("\nToken reference:")
for i, token in enumerate(tokens):
    print(f"  Token {i}: {repr(token)}")

print(f"Saved bin embeddings to: {bin_filename}")
print(f"Saved txt embeddings to: {txt_filename}")
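# Comparison sketch: the saved embeddings are presumably diffed against another
# implementation's output for the same prompt. A minimal cosine-similarity
# check between two such .bin files (the second file name is a hypothetical
# example) could look like:
#
#   a = np.fromfile(bin_filename, dtype=np.float32)
#   b = np.fromfile(f"data/llamacpp-{model_name}-embeddings.bin", dtype=np.float32)
#   cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
#   print(f"Cosine similarity: {cos:.6f}")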