Просмотр исходного кода

finetune : fix #3404 (#3437)

the shapes for init model of gqa models was wrong
xaedes 2 лет назад
Родитель
Сommit
a03ce38455
1 измененных файлов с 2 добавлено и 2 удалено
  1. 2 2
      examples/finetune/finetune.cpp

+ 2 - 2
examples/finetune/finetune.cpp

@@ -332,8 +332,8 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
 
 
         assert_shape_1d(layer.attention_norm, hparams.n_embd);
         assert_shape_1d(layer.attention_norm, hparams.n_embd);
         assert_shape_2d(layer.wq,             hparams.n_embd, hparams.n_embd);
         assert_shape_2d(layer.wq,             hparams.n_embd, hparams.n_embd);
-        assert_shape_2d(layer.wk,             hparams.n_embd, hparams.n_embd);
-        assert_shape_2d(layer.wv,             hparams.n_embd, hparams.n_embd);
+        assert_shape_2d(layer.wk,             hparams.n_embd, hparams.n_embd_gqa());
+        assert_shape_2d(layer.wv,             hparams.n_embd, hparams.n_embd_gqa());
         assert_shape_2d(layer.wo,             hparams.n_embd, hparams.n_embd);
         assert_shape_2d(layer.wo,             hparams.n_embd, hparams.n_embd);
         assert_shape_1d(layer.ffn_norm,       hparams.n_embd);
         assert_shape_1d(layer.ffn_norm,       hparams.n_embd);
         assert_shape_2d(layer.w1,             hparams.n_embd, hparams.n_ff);
         assert_shape_2d(layer.w1,             hparams.n_embd, hparams.n_ff);