summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Zhang Peiyuan <a1286225768@gmail.com>, 2023-09-28 02:45:20 +0800
committer: GitHub <noreply@github.com>, 2023-09-27 20:45:20 +0200
commit: e519621010cac02c6fec0f8f3b16cda0591042c0 (patch)
tree: bf403a765d8ce4e06abee9c576ded6c9d27fc243
parent: ac43576124a75c2de6e333ac31a3444ff9eb9458 (diff)
convert : remove bug in convert.py permute function (#3364)
-rwxr-xr-x convert.py 2
1 files changed, 1 insertions, 1 deletions
diff --git a/convert.py b/convert.py
index 4ac5030d..8bb6c7e4 100755
--- a/convert.py
+++ b/convert.py
@@ -439,7 +439,7 @@ Vocab: TypeAlias = 'BpeVocab | SentencePieceVocab'
def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
#print( "permute debug " + str(weights.shape[0]) + " x " + str(weights.shape[1]) + " nhead " + str(n_head) + " nheadkv " + str(n_kv_head) )
if n_head_kv is not None and n_head != n_head_kv:
- n_head //= n_head_kv
+ n_head = n_head_kv
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape))