# download-pth.py: download LLaMA model weights and tokenizer files.

import os
import sys

import requests
from tqdm import tqdm
  5. if len(sys.argv) < 3:
  6. print("Usage: download-pth.py dir-model model-type\n")
  7. print(" model-type: Available models 7B, 13B, 30B or 65B")
  8. sys.exit(1)
  9. modelsDir = sys.argv[1]
  10. model = sys.argv[2]
  11. num = {
  12. "7B": 1,
  13. "13B": 2,
  14. "30B": 4,
  15. "65B": 8,
  16. }
  17. if model not in num:
  18. print(f"Error: model {model} is not valid, provide 7B, 13B, 30B or 65B")
  19. sys.exit(1)
  20. print(f"Downloading model {model}")
  21. files = ["checklist.chk", "params.json"]
  22. for i in range(num[model]):
  23. files.append(f"consolidated.0{i}.pth")
  24. resolved_path = os.path.abspath(os.path.join(modelsDir, model))
  25. os.makedirs(resolved_path, exist_ok=True)
  26. for file in files:
  27. dest_path = os.path.join(resolved_path, file)
  28. if os.path.exists(dest_path):
  29. print(f"Skip file download, it already exists: {file}")
  30. continue
  31. url = f"https://agi.gpt4.org/llama/LLaMA/{model}/{file}"
  32. response = requests.get(url, stream=True)
  33. with open(dest_path, 'wb') as f:
  34. with tqdm(unit='B', unit_scale=True, miniters=1, desc=file) as t:
  35. for chunk in response.iter_content(chunk_size=1024):
  36. if chunk:
  37. f.write(chunk)
  38. t.update(len(chunk))
  39. files2 = ["tokenizer_checklist.chk", "tokenizer.model"]
  40. for file in files2:
  41. dest_path = os.path.join(modelsDir, file)
  42. if os.path.exists(dest_path):
  43. print(f"Skip file download, it already exists: {file}")
  44. continue
  45. url = f"https://agi.gpt4.org/llama/LLaMA/{file}"
  46. response = requests.get(url, stream=True)
  47. with open(dest_path, 'wb') as f:
  48. with tqdm(unit='B', unit_scale=True, miniters=1, desc=file) as t:
  49. for chunk in response.iter_content(chunk_size=1024):
  50. if chunk:
  51. f.write(chunk)
  52. t.update(len(chunk))