Explorar o código

convert.py : export rope freq_base when converting CodeLlama from an HF model (#2773)

slaren %!s(int64=2) %!d(string=hai) anos
pai
achega
12e2e33a97
Modificáronse 1 ficheiros con 18 adicións e 16 borrados
  1. 18 16
      convert.py

+ 18 - 16
convert.py

@@ -160,13 +160,14 @@ class Params:
     def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
         config = json.load(open(config_path))
 
-        n_vocab    = config["vocab_size"]
-        n_embd     = config["hidden_size"]
-        n_layer    = config["num_hidden_layers"]
-        n_ff       = config["intermediate_size"]
-        n_head     = config["num_attention_heads"]
-        n_head_kv  = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
-        f_norm_eps = config["rms_norm_eps"]
+        n_vocab          = config["vocab_size"]
+        n_embd           = config["hidden_size"]
+        n_layer          = config["num_hidden_layers"]
+        n_ff             = config["intermediate_size"]
+        n_head           = config["num_attention_heads"]
+        n_head_kv        = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
+        f_norm_eps       = config["rms_norm_eps"]
+        f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
 
         n_mult = Params.find_n_mult(n_ff, n_embd)
 
@@ -179,15 +180,16 @@ class Params:
                             "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
 
         return Params(
-            n_vocab    = n_vocab,
-            n_embd     = n_embd,
-            n_mult     = n_mult,
-            n_layer    = n_layer,
-            n_ctx      = n_ctx,
-            n_ff       = n_ff,
-            n_head     = n_head,
-            n_head_kv  = n_head_kv,
-            f_norm_eps = f_norm_eps,
+            n_vocab          = n_vocab,
+            n_embd           = n_embd,
+            n_mult           = n_mult,
+            n_layer          = n_layer,
+            n_ctx            = n_ctx,
+            n_ff             = n_ff,
+            n_head           = n_head,
+            n_head_kv        = n_head_kv,
+            f_norm_eps       = f_norm_eps,
+            f_rope_freq_base = f_rope_freq_base,
         )
 
     # LLaMA v2 70B params.json