summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml/src/ggml-metal.metal15
1 files changed, 8 insertions, 7 deletions
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 6553f465..259fa609 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -2904,11 +2904,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
const short iv3 = iq3 / rv3;
// load the queries from shared memory into local memory
- half4 mq[D4];
+ float4 mq[D4];
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
- mq[i] = sq4[i];
+ mq[i] = (float4)sq4[i];
}
// pointer to the mask
@@ -2934,11 +2934,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;
- half4x4 mk;
- mk[0] = pk4[i + 0*(nb11/8)];
- mk[1] = pk4[i + 1*(nb11/8)];
- mk[2] = pk4[i + 2*(nb11/8)];
- mk[3] = pk4[i + 3*(nb11/8)];
+ float4x4 mk;
+ mk[0] = (float4)pk4[i + 0*(nb11/8)];
+ mk[1] = (float4)pk4[i + 1*(nb11/8)];
+ mk[2] = (float4)pk4[i + 2*(nb11/8)];
+ mk[3] = (float4)pk4[i + 3*(nb11/8)];
mqk += (float4) (mq[i] * mk);
}
@@ -2960,6 +2960,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
ss4[cc] = mqk;
}
+
}
}