|
@@ -439,7 +439,7 @@ Vocab: TypeAlias = 'BpeVocab | SentencePieceVocab'
|
|
|
def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
|
|
def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
|
|
|
#print( "permute debug " + str(weights.shape[0]) + " x " + str(weights.shape[1]) + " nhead " + str(n_head) + " nheadkv " + str(n_kv_head) )
|
|
#print( "permute debug " + str(weights.shape[0]) + " x " + str(weights.shape[1]) + " nhead " + str(n_head) + " nheadkv " + str(n_kv_head) )
|
|
|
if n_head_kv is not None and n_head != n_head_kv:
|
|
if n_head_kv is not None and n_head != n_head_kv:
|
|
|
- n_head //= n_head_kv
|
|
|
|
|
|
|
+ n_head = n_head_kv
|
|
|
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
|
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
|
|
.swapaxes(1, 2)
|
|
.swapaxes(1, 2)
|
|
|
.reshape(weights.shape))
|
|
.reshape(weights.shape))
|