summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-04-25 11:01:08 +0200
committerGitHub <noreply@github.com>2025-04-25 11:01:08 +0200
commit25d1a0dca87e8d0237471e6d426a971e1d5289a2 (patch)
treec96506f472b24fe26df3dc13bf7b7f8ef205086d
parentf176122a3d50c781414458b498b9426086a91647 (diff)
Fix FA on ARM (#346)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp12
1 files changed, 12 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 29ae9d53..45d804a4 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -16142,7 +16142,13 @@ struct FlashQKV {
std::memcpy(S, fms.S, nq1*sizeof(float));
auto R = qkv_cache;
for (int j = 0; j < nq1; ++j) {
+#ifdef __aarch64__
+ for (int i = 0; i < D/F16::block_size; ++i) {
+ F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i));
+ }
+#else
std::memcpy(qkv, R, D*sizeof(float));
+#endif
qkv += stride_qkv;
R += D;
}
@@ -16162,7 +16168,13 @@ struct FlashQKV {
std::memcpy(S, fms.S, q_step*sizeof(float));
auto R = qkv_cache;
for (int j = 0; j < q_step; ++j) {
+#ifdef __aarch64__
+ for (int i = 0; i < D/F16::block_size; ++i) {
+ F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i));
+ }
+#else
std::memcpy(qkv, R, D*sizeof(float));
+#endif
qkv += stride_qkv;
R += D;
}