|
@@ -3065,6 +3065,9 @@ class Rwkv6Model(Model):
|
|
|
if new_name.endswith("time_mix_w2.weight"):
|
|
if new_name.endswith("time_mix_w2.weight"):
|
|
|
data_torch = data_torch.permute(0, 2, 1)
|
|
data_torch = data_torch.permute(0, 2, 1)
|
|
|
|
|
|
|
|
|
|
+ if new_name.endswith("time_mix_decay.weight") or "lerp" in new_name:
|
|
|
|
|
+ data_torch = data_torch.squeeze()
|
|
|
|
|
+
|
|
|
rescale_every_n_layers = self.hparams["rescale_every"]
|
|
rescale_every_n_layers = self.hparams["rescale_every"]
|
|
|
if rescale_every_n_layers > 0:
|
|
if rescale_every_n_layers > 0:
|
|
|
if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
|
|
if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
|