summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohannes Gäßler <johannesg@5d6.de>2024-06-12 17:41:51 +0200
committerGitHub <noreply@github.com>2024-06-12 17:41:51 +0200
commit963552903f51043ee947a8deeaaa7ec00bc3f1a4 (patch)
tree1ca4308d7300f4ef2e0628f2057f73692e4bcc34
parenta9cae48003dfc4fe95b8f5c81682fc6e63425235 (diff)
CUDA: fix broken oob check for FA vec f32 kernel (#7904)
-rw-r--r--ggml-cuda/fattn-vec-f32.cuh2
1 files changed, 1 insertions, 1 deletions
diff --git a/ggml-cuda/fattn-vec-f32.cuh b/ggml-cuda/fattn-vec-f32.cuh
index ddf0c837..11a5e355 100644
--- a/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml-cuda/fattn-vec-f32.cuh
@@ -149,7 +149,7 @@ static __global__ void flash_attn_vec_ext_f32(
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
- Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
+ Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
Q_f2[j][i0/WARP_SIZE].x *= scale;
Q_f2[j][i0/WARP_SIZE].y *= scale;
}