diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-05-11 08:12:47 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-05-11 08:12:47 +0300 |
commit | 36e6e888b75ae93fb5aac212bb0e147d8379ae23 (patch) | |
tree | 30c6c834fe25250c8d1b2defc5282ce5564257fa | |
parent | a2d24c97e5c5c28aeb3669dcc0044b69258a85ca (diff) |
Fix race in the CUDA DeepSeek FA kernel (#406)
Reference: https://github.com/ggml-org/llama.cpp/pull/13438
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml-cuda/fattn-new-mma.cu | 2 |
1 files changed, 2 insertions, 0 deletions
diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index d1484451..8da96370 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -898,6 +898,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } + __syncthreads(); + // Write back combined meta data: #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { |