Diffstat (limited to 'ggml-cuda/fattn-tile-f16.cu')
| -rw-r--r-- | ggml-cuda/fattn-tile-f16.cu | 4 |
1 file changed, 2 insertions, 2 deletions
diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu
index 3d64a9eb..cb11d721 100644
--- a/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml-cuda/fattn-tile-f16.cu
@@ -278,13 +278,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int D = 64;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
             GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
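
The change is purely at the call sites: both launch_fattn calls now pass two additional trailing booleans, set to true for this tile kernel. The diff itself does not show what those booleans mean; a plausible reading (an assumption, not confirmed by this page) is that they flag whether K and V need to be provided in F16 for the kernel. The sketch below is a minimal, self-compiling C++ illustration of that shape, not the real ggml-cuda code: launch_fattn is reduced to a stub, the parameter names need_f16_K/need_f16_V and the template/argument values (4, 32) are hypothetical, and the real function also takes the backend context and destination tensor, which are omitted here.

// Hypothetical sketch of the call-site shape after this commit.
// Assumed: the two new trailing bools select F16 handling for K and V.
#include <cstdio>

typedef void (*fattn_kernel_t)(); // stand-in for the real CUDA kernel function type

template <int D, int parallel_blocks>
static void launch_fattn(fattn_kernel_t kernel, int nwarps, int cols_per_block,
                         bool need_f16_K, bool need_f16_V) {
    // The real launcher would set up grid/block dimensions and launch the kernel;
    // here we only show how the two new arguments flow through the interface.
    std::printf("D=%d pb=%d nwarps=%d cols=%d f16_K=%d f16_V=%d\n",
                D, parallel_blocks, nwarps, cols_per_block,
                (int) need_f16_K, (int) need_f16_V);
    (void) kernel;
}

static void dummy_kernel() {}

int main() {
    // Mirrors the two updated call sites: both booleans passed as true.
    launch_fattn<64, 4>(dummy_kernel, /*nwarps=*/8, /*cols_per_block=*/32, true, true);
    launch_fattn<128, 4>(dummy_kernel, /*nwarps=*/8, /*cols_per_block=*/32, true, true);
    return 0;
}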
