summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml-cuda/fattn-tile-f16.cu4
-rw-r--r--ggml-cuda/fattn-tile-f32.cu4
-rw-r--r--ggml-cuda/fattn-vec-f16.cu6
-rw-r--r--ggml-cuda/fattn-vec-f32.cu6
4 files changed, 18 insertions, 2 deletions
diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu
index 4a07ac6a..586d469c 100644
--- a/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml-cuda/fattn-tile-f16.cu
@@ -238,6 +238,10 @@ static __global__ void flash_attn_tile_ext_f16(
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
+ if (ic0 + j_VKQ >= ne01) {
+ return;
+ }
+
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
kqsum_j = warp_reduce_sum(kqsum_j);
diff --git a/ggml-cuda/fattn-tile-f32.cu b/ggml-cuda/fattn-tile-f32.cu
index b8b2f69e..b6ef8eb4 100644
--- a/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml-cuda/fattn-tile-f32.cu
@@ -237,6 +237,10 @@ static __global__ void flash_attn_tile_ext_f32(
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
+ if (ic0 + j_VKQ >= ne01) {
+ return;
+ }
+
float kqsum_j = kqsum[j_VKQ_0/nwarps];
kqsum_j = warp_reduce_sum(kqsum_j);
diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu
index 54e1ac5d..7352dcab 100644
--- a/ggml-cuda/fattn-vec-f16.cu
+++ b/ggml-cuda/fattn-vec-f16.cu
@@ -212,6 +212,10 @@ static __global__ void flash_attn_vec_ext_f16(
#pragma unroll
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+ if (ic0 + j_VKQ >= ne01) {
+ break;
+ }
+
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
@@ -223,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
}
- if (parallel_blocks != 1 && tid < ncols) {
+ if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
diff --git a/ggml-cuda/fattn-vec-f32.cu b/ggml-cuda/fattn-vec-f32.cu
index 5bcabd09..11476a6c 100644
--- a/ggml-cuda/fattn-vec-f32.cu
+++ b/ggml-cuda/fattn-vec-f32.cu
@@ -200,6 +200,10 @@ static __global__ void flash_attn_vec_ext_f32(
#pragma unroll
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+ if (ic0 + j_VKQ >= ne01) {
+ break;
+ }
+
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
@@ -211,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32(
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
}
- if (parallel_blocks != 1 && tid < ncols) {
+ if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
}
}