From 36e6e888b75ae93fb5aac212bb0e147d8379ae23 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 11 May 2025 08:12:47 +0300 Subject: Fix race in the CUDA DeepSeek FA kernel (#406) Reference: https://github.com/ggml-org/llama.cpp/pull/13438 Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda/fattn-new-mma.cu | 2 ++ 1 file changed, 2 insertions(+) (limited to 'ggml/src') diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index d1484451..8da96370 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -898,6 +898,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } + __syncthreads(); + // Write back combined meta data: #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { -- cgit v1.2.3