diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-09-25 13:08:55 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-25 13:08:55 +0300 |
commit | 546f3ef349a7082fbc349897c3c7246baed2a6c6 (patch) | |
tree | 462896dc6a1167f2b4c866fa929ebc4a7230771d /ggml/src | |
parent | be57912955f3f6053a146a4062f7e2dc5a7d7a41 (diff) |
Use fp32 for K*Q in Metal FA implementation (#62)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src')
-rw-r--r-- | ggml/src/ggml-metal.metal | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 6553f465..259fa609 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2904,11 +2904,11 @@ kernel void kernel_flash_attn_ext_vec_f16( const short iv3 = iq3 / rv3; // load the queries from shared memory into local memory - half4 mq[D4]; + float4 mq[D4]; for (short ii = 0; ii < D4; ii += NW) { short i = ii + tiisg; - mq[i] = sq4[i]; + mq[i] = (float4)sq4[i]; } // pointer to the mask @@ -2934,11 +2934,11 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short ii = 0; ii < D4; ii += NW) { const short i = ii + tiisg; - half4x4 mk; - mk[0] = pk4[i + 0*(nb11/8)]; - mk[1] = pk4[i + 1*(nb11/8)]; - mk[2] = pk4[i + 2*(nb11/8)]; - mk[3] = pk4[i + 3*(nb11/8)]; + float4x4 mk; + mk[0] = (float4)pk4[i + 0*(nb11/8)]; + mk[1] = (float4)pk4[i + 1*(nb11/8)]; + mk[2] = (float4)pk4[i + 2*(nb11/8)]; + mk[3] = (float4)pk4[i + 3*(nb11/8)]; mqk += (float4) (mq[i] * mk); } @@ -2960,6 +2960,7 @@ kernel void kernel_flash_attn_ext_vec_f16( ss4[cc] = mqk; } + } } |