summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-09-15 20:17:24 +0300
committerGitHub <noreply@github.com>2023-09-15 20:17:24 +0300
commitc6f1491da032238241e01021c8c58d7b540a043f (patch)
treef5edfd61ee40e916510188619b8ae5bd7ad480e2
parente3d87a6c36eadd84d58763143c6d56a0c771ca40 (diff)
metal : fix bug in soft_max kernels (out-of-bounds access) (#3194)
-rw-r--r--ggml-metal.metal4
1 files changed, 2 insertions, 2 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 3087ecda..7f1c3d9e 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -118,7 +118,7 @@ kernel void kernel_soft_max(
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max
- float lmax = psrc0[tpitg[0]];
+ float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
lmax = MAX(lmax, psrc0[i00]);
}
@@ -158,7 +158,7 @@ kernel void kernel_soft_max_4(
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
// parallel max
- float4 lmax4 = psrc4[tpitg[0]];
+ float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
lmax4 = fmax(lmax4, psrc4[i00]);
}