From 25d1a0dca87e8d0237471e6d426a971e1d5289a2 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 25 Apr 2025 11:01:08 +0200 Subject: Fix FA on ARM (#346) Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'ggml/src') diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 29ae9d53..45d804a4 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -16142,7 +16142,13 @@ struct FlashQKV { std::memcpy(S, fms.S, nq1*sizeof(float)); auto R = qkv_cache; for (int j = 0; j < nq1; ++j) { +#ifdef __aarch64__ + for (int i = 0; i < D/F16::block_size; ++i) { + F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i)); + } +#else std::memcpy(qkv, R, D*sizeof(float)); +#endif qkv += stride_qkv; R += D; } @@ -16162,7 +16168,13 @@ struct FlashQKV { std::memcpy(S, fms.S, q_step*sizeof(float)); auto R = qkv_cache; for (int j = 0; j < q_step; ++j) { +#ifdef __aarch64__ + for (int i = 0; i < D/F16::block_size; ++i) { + F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i)); + } +#else std::memcpy(qkv, R, D*sizeof(float)); +#endif qkv += stride_qkv; R += D; } -- cgit v1.2.3