run-casual-gen-embeddings-org.sh

#!/usr/bin/env python3

import argparse
import os
import importlib
import sys

import torch
import numpy as np
from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM
from pathlib import Path
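# Usage note (derived from the argument/env handling below; the exact
# invocation is an assumption since the surrounding docs are not shown):
#   ./run-casual-gen-embeddings-org.sh --model-path /path/to/model
#   MODEL_PATH=/path/to/model ./run-casual-gen-embeddings-org.sh
# Optionally set UNRELEASED_MODEL_NAME to load a model class that is not
# yet wired into transformers' Auto* registry (see below).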
unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')

parser = argparse.ArgumentParser(description='Process model with specified path')
parser.add_argument('--model-path', '-m', help='Path to the model')
args = parser.parse_args()

# The MODEL_PATH environment variable takes precedence over the CLI flag.
model_path = os.environ.get('MODEL_PATH', args.model_path)
if model_path is None:
    parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
config = AutoConfig.from_pretrained(model_path)

print("Model type: ", config.model_type)
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)

print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
if unreleased_model_name:
    model_name_lower = unreleased_model_name.lower()
    unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
    class_name = f"{unreleased_model_name}ForCausalLM"
    print(f"Importing unreleased model module: {unreleased_module_path}")

    try:
        model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
        model = model_class.from_pretrained(model_path)
    except (ImportError, AttributeError) as e:
        print(f"Failed to import or load model: {e}")
        # Without a loaded model nothing below can run, so exit instead of
        # crashing later with an undefined `model`.
        sys.exit(1)
else:
    model = AutoModelForCausalLM.from_pretrained(model_path)
print(f"Model class: {type(model)}")
#print(f"Model file: {type(model).__module__}")

model_name = os.path.basename(model_path)
print(f"Model name: {model_name}")
prompt = "Hello world today"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
with torch.no_grad():
    outputs = model(input_ids, output_hidden_states=True)

# Extract hidden states from the last layer
# outputs.hidden_states is a tuple of (num_layers + 1) tensors
# Index -1 gets the last layer, shape: [batch_size, seq_len, hidden_size]
last_hidden_states = outputs.hidden_states[-1]

# Get embeddings for all tokens
token_embeddings = last_hidden_states[0].cpu().numpy()  # Remove batch dimension

print(f"Hidden states shape: {last_hidden_states.shape}")
print(f"Token embeddings shape: {token_embeddings.shape}")
print(f"Hidden dimension: {token_embeddings.shape[-1]}")
print(f"Number of tokens: {token_embeddings.shape[0]}")
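# Note: the script works with per-token embeddings throughout. If a single
# sequence-level embedding were needed (not done here; this is an
# illustrative aside), a common approach is mean pooling over tokens:
#   sequence_embedding = token_embeddings.mean(axis=0)  # [hidden_size]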
# Save raw token embeddings
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"

# Save all token embeddings as binary
print(token_embeddings)
token_embeddings.astype(np.float32).tofile(bin_filename)

# Save as text for inspection
with open(txt_filename, "w") as f:
    for i, embedding in enumerate(token_embeddings):
        for j, val in enumerate(embedding):
            f.write(f"{i} {j} {val:.6f}\n")
# Print embeddings per token in the requested format
print("\nToken embeddings:")
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
for i, embedding in enumerate(token_embeddings):
    # Format: show first few values, ..., then last few values
    if len(embedding) > 10:
        # Show first 3 and last 3 values with ... in between
        first_vals = " ".join(f"{val:8.6f}" for val in embedding[:3])
        last_vals = " ".join(f"{val:8.6f}" for val in embedding[-3:])
        print(f"embedding {i}: {first_vals} ... {last_vals}")
    else:
        # If embedding is short, show all values
        vals = " ".join(f"{val:8.6f}" for val in embedding)
        print(f"embedding {i}: {vals}")
# Also show token info for reference
print("\nToken reference:")
for i, token in enumerate(tokens):
    print(f"  Token {i}: {repr(token)}")

print(f"Saved binary embeddings to: {bin_filename}")
print(f"Saved text embeddings to: {txt_filename}")
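# Sketch (an assumption about downstream use, not part of this script's
# pipeline): the raw float32 dump can be read back with NumPy for later
# comparison, since tofile() stores no shape metadata:
#   loaded = np.fromfile(bin_filename, dtype=np.float32)
#   loaded = loaded.reshape(token_embeddings.shape)  # [num_tokens, hidden_size]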