summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--ggml/src/ggml-cuda/fattn-new-mma.cu8
1 files changed, 8 insertions, 0 deletions
diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu
index 5aeca3c4..6178b3e5 100644
--- a/ggml/src/ggml-cuda/fattn-new-mma.cu
+++ b/ggml/src/ggml-cuda/fattn-new-mma.cu
@@ -1093,6 +1093,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
}
+ // do we really need this?
+ __syncthreads();
+
// Write back combined meta data:
#pragma unroll
for (int imeta = 0; imeta < nmeta; ++imeta) {
@@ -1112,6 +1115,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
+ } else if (np > 1) {
+ // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
+ // Therefore, all other warps also need to execute a __syncthreads().
+ // Otherwise the points at which warps synchronize with each other would become misaligned.
+ __syncthreads();
}
#pragma unroll