 ggml/src/CMakeLists.txt             |   2
 ggml/src/ggml-cuda/cp-async.cuh     |  10
 ggml/src/ggml-cuda/fattn-new-mma.cu | 451
 src/llama.cpp                       |  23
 4 files changed, 356 insertions(+), 130 deletions(-)
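This patch adapts the MLA FlashAttention MMA kernel from llama.cpp PR 13435. On the llama.cpp side the per-layer cache rows are reordered so that the RoPE'd part is stored first and the compressed KV after it; the CUDA kernel can then treat V as a suffix of K (V_h2 = K_h2 + (DKQ/2 - DV/2)) instead of loading a separate V tensor. Below is a minimal host-side sketch of that offset arithmetic, assuming the DeepSeek-style head sizes this kernel instantiates (576/512); the example and its variable names are illustrative only, not part of the patch.

    // Minimal host-side sketch of the reordered cache layout, assuming the
    // DeepSeek-style MLA shapes this kernel targets; names are illustrative.
    #include <cstdio>

    int main() {
        const int n_embd_head_qk_rope = 64;   // RoPE'd part, now stored first in each cache row
        const int kv_lora_rank        = 512;  // compressed KV, stored after it

        const int DKQ = n_embd_head_qk_rope + kv_lora_rank; // 576: K head size seen by the kernel
        const int DV  = kv_lora_rank;                       // 512: V head size seen by the kernel

        // fattn-new-mma.cu aliases V into K via half2 pointers:
        //     V_h2 = K_h2 + (DKQ/2 - DV/2)
        const int v_offset_half2 = DKQ/2 - DV/2;            // 32 half2 = 64 f16 values

        // llama.cpp builds the matching views by skipping the RoPE part of each row,
        // i.e. an offset of ggml_row_size(type, n_embd_head_qk_rope) bytes.
        const int v_offset_bytes = n_embd_head_qk_rope * 2; // 128 bytes for an f16 cache

        printf("V view starts %d half2 (%d bytes) into each K row\n",
               v_offset_half2, v_offset_bytes);
        return 0;
    }
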
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 74ac5374..4f4337c2 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -986,7 +986,7 @@ endif() set(CUDA_CXX_FLAGS "") if (GGML_CUDA) - set(CUDA_FLAGS -use_fast_math) + set(CUDA_FLAGS -use_fast_math -extended-lambda) if (GGML_FATAL_WARNINGS) list(APPEND CUDA_FLAGS -Werror all-warnings) diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh index ecb65999..a87dc247 100644 --- a/ggml/src/ggml-cuda/cp-async.cuh +++ b/ggml/src/ggml-cuda/cp-async.cuh @@ -2,6 +2,16 @@ #include "common.cuh" +static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) { +#ifdef CP_ASYNC_AVAILABLE + return __cvta_generic_to_shared(generic_ptr); +#else + GGML_UNUSED(generic_ptr); + NO_DEVICE_CODE; + return 0; +#endif // CP_ASYNC_AVAILABLE +} + // Copies data from global to shared memory, cg == cache global. // Both the src and dst pointers must be aligned to 16 bit. // Shared memory uses 32 bit addressing, the pointer is passed as unsigned int. diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index d1484451..630baf33 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -1,15 +1,16 @@ -// Adapted from https://github.com/ggml-org/llama.cpp/pull/13306 +// Adapted from https://github.com/ggml-org/llama.cpp/pull/13435 // -// Copyright (C) 2023-2024 The ggml authors -// Copyright (C) 2024 Iwan Kawrakow +// Copyright (C) 2025 The ggml authors +// Copyright (C) 2025 Iwan Kawrakow // MIT license // SPDX-License-Identifier: MIT // -#include "fattn-new-mma.cuh" +#include "common.cuh" #include "cp-async.cuh" #include "mma_new.cuh" #include "fattn-common.cuh" +#include "fattn-new-mma.cuh" using namespace ggml_cuda_mma; @@ -39,6 +40,8 @@ struct fattn_mma_f16_config; // The previous MMA version is better (faster) // I'm keeping these around commented out for now, // and only using the 576, 512 case. +// Perhaps the 256 head size needs a closer look +// to see if this implementation is better. 
// //template <> //struct fattn_mma_f16_config< 64, 64> { @@ -46,9 +49,30 @@ struct fattn_mma_f16_config; // static constexpr int nwarps_max = 4; // static constexpr bool Q_in_reg = true; // static constexpr int nstages_target = 2; -// static constexpr int nbatch_K2 = 32; -// static constexpr int nbatch_V2 = 32; -// static constexpr int nbatch_combine = 32; +// +// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { +// return 32; +// } +// +// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { +// return 32; +// } +// +// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { +// return 32; +// } +// +// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { +// return 32; +// } +// +// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { +// return 32; +// } +// +// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { +// return 32; +// } //}; // //template <> @@ -57,9 +81,30 @@ struct fattn_mma_f16_config; // static constexpr int nwarps_max = 4; // static constexpr bool Q_in_reg = true; // static constexpr int nstages_target = 2; -// static constexpr int nbatch_K2 = 40; -// static constexpr int nbatch_V2 = 40; -// static constexpr int nbatch_combine = 40; +// +// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { +// return 40; +// } +// +// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { +// return 40; +// } +// +// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { +// return 40; +// } +// +// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { +// return 40; +// } +// +// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { +// return 40; +// } +// +// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { +// return 40; +// } //}; // //template <> @@ -68,9 +113,30 @@ struct fattn_mma_f16_config; // static constexpr int nwarps_max = 4; // static constexpr bool Q_in_reg = true; // static constexpr int nstages_target = 2; -// static constexpr int nbatch_K2 = 48; -// static constexpr int nbatch_V2 = 48; -// static constexpr int nbatch_combine = 48; +// +// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { +// return 48; +// } +// +// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { +// return 48; +// } +// +// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { +// return 48; +// } +// +// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { +// return 48; +// } +// +// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { +// return 48; +// } +// +// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { +// return 48; +// } //}; // //template <> @@ -79,9 +145,30 @@ struct fattn_mma_f16_config; // static constexpr int nwarps_max = 4; // static constexpr bool Q_in_reg = true; // static constexpr int nstages_target = 2; -// static constexpr int nbatch_K2 = 56; -// static constexpr int nbatch_V2 = 56; -// static constexpr int nbatch_combine = 56; +// +// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { +// return 56; +// } +// +// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { +// return 56; +// } +// +// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { +// return 56; +// } +// +// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { +// return 56; +// } 
+// +// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { +// return 56; +// } +// +// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { +// return 56; +// } //}; // //template <> @@ -90,20 +177,30 @@ struct fattn_mma_f16_config; // static constexpr int nwarps_max = 4; // static constexpr bool Q_in_reg = true; // static constexpr int nstages_target = 2; -// static constexpr int nbatch_K2 = 64; -// static constexpr int nbatch_V2 = 64; -// static constexpr int nbatch_combine = 64; -//}; // -//template <> -//struct fattn_mma_f16_config<192, 128> { -// static constexpr int nbatch_fa = 64; -// static constexpr int nwarps_max = 4; -// static constexpr bool Q_in_reg = true; -// static constexpr int nstages_target = 2; -// static constexpr int nbatch_K2 = 96; -// static constexpr int nbatch_V2 = 64; -// static constexpr int nbatch_combine = 64; +// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { +// return 64; +// } +// +// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { +// return 64; +// } +// +// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { +// return 64; +// } +// +// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { +// return 64; +// } +// +// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { +// return 64; +// } +// +// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { +// return 64; +// } //}; // //template <> @@ -112,9 +209,38 @@ struct fattn_mma_f16_config; // static constexpr int nwarps_max = 4; // static constexpr bool Q_in_reg = true; // static constexpr int nstages_target = 2; -// static constexpr int nbatch_K2 = 128; -// static constexpr int nbatch_V2 = 128; -// static constexpr int nbatch_combine = 128; +// +// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { +// return 128; +// } +// +// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { +// return 128; +// } +// +// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { +// return 128; +// } +// +// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { +// return 128; +// } +// +// static int get_nbatch_combine_host(const int cc, const int ncols) { +// if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING) { +// return ncols <= 16 ? 128 : 64; +// } +// return 64; +// } +// +// static constexpr __device__ int get_nbatch_combine_device(int ncols) { +//#if __CUDA_ARCH__ == CC_TURING +// return ncols <= 16 ? 128 : 64; +//#else +// GGML_UNUSED(ncols); +// return 128; +//#endif // __CUDA_ARCH__ == CC_TURING +// } //}; template <> @@ -123,9 +249,65 @@ struct fattn_mma_f16_config<576, 512> { static constexpr int nwarps_max = 8; static constexpr bool Q_in_reg = false; static constexpr int nstages_target = 1; - static constexpr int nbatch_K2 = 160; - static constexpr int nbatch_V2 = 128; - static constexpr int nbatch_combine = 128; + + static int get_nbatch_K2_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING) { + return ncols <= 16 ? 96 : 160; + } + return ncols <= 16 ? 288 : 160; + } + + static constexpr __device__ int get_nbatch_K2_device(int ncols) { +#if __CUDA_ARCH__ == CC_TURING + return ncols <= 16 ? 96 : 160; +#else + return ncols <= 16 ? 
288 : 160; +#endif // __CUDA_ARCH__ == CC_TURING + } + + static int get_nbatch_V2_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING) { + return ncols <= 16 ? 64 : 128; + } + return ncols <= 16 ? 256 : 128; + } + + static constexpr __device__ int get_nbatch_V2_device(int ncols) { +#if __CUDA_ARCH__ == GML_CUDA_CC_TURING + return ncols <= 16 ? 64 : 128; +#else + return ncols <= 16 ? 256 : 128; +#endif // __CUDA_ARCH__ == GML_CUDA_CC_TURING + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 128; + } +}; + +// ------------------------------------------------------------------------------------------------------------------ + +// The compiler is always able to unroll loops if they contain continue expressions. +// In such cases loop unrolling can still be achieved via recursion: +template <int n> +struct ggml_cuda_unroll { + template <typename Func, typename... Args> + __device__ void operator()(const Func & f, Args... args) const { + f(n - 1, args...); + ggml_cuda_unroll<n - 1>{}(f, args...); + } +}; + +template <> +struct ggml_cuda_unroll<1> { + template <typename Func, typename... Args> + __device__ void operator()(const Func & f, Args... args) const { + f(0, args...); + } }; template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async> @@ -136,26 +318,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. if constexpr (use_cp_async) { - const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); - constexpr int preload = 64; constexpr int h2_per_chunk = 16/sizeof(half2); - const int chunks_per_row = D2 / h2_per_chunk; - int k0_start = 0; -#pragma unroll - for (int stride_k = WARP_SIZE; stride_k > WARP_SIZE/32; stride_k >>= 1) { + const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); + + auto load = [&] __device__ (auto n) { + const int stride_k = WARP_SIZE >> n; + const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; if (k0_start == k0_stop) { - continue; + return; } - const int stride_i = WARP_SIZE / stride_k; #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + threadIdx.x / stride_k; + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { break; @@ -168,18 +349,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); } } - k0_start = k0_stop; - } + }; + ggml_cuda_unroll<5>{}(load); } else { static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); -#pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + auto load = [&] __device__ (const int n) { + const int stride_k = WARP_SIZE >> n; const int k0_start = stride_k == WARP_SIZE ? 
0 : D2 - D2 % (2*stride_k); const int k0_stop = D2 - D2 % (1*stride_k); const int stride_i = WARP_SIZE / stride_k; - if (k0_start == k0_stop || k0_stop <= 0) { - continue; + if (k0_start == k0_stop) { + return; } #pragma unroll @@ -197,7 +378,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; } } - } + }; + ggml_cuda_unroll<3>{}(load); } } @@ -211,7 +393,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; - const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask); + const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); #pragma unroll for (int j0 = 0; j0 < ncols1; j0 += stride_j) { @@ -245,7 +427,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } } -template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter> +template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter> static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -283,10 +465,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int ncols = ncols1 * ncols2; + constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); + constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); - constexpr int stride_tile_Q = DKQ/2 + 4; - constexpr int stride_tile_K = c::nbatch_K2 + 4; - constexpr int stride_tile_V = c::nbatch_V2 + 4; + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = nbatch_K2 + 4; + + static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); + constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; const int k_VKQ_0 = kb0 * c::nbatch_fa; tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; @@ -297,29 +484,30 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; if constexpr (nstages > 1) { - static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); + static_assert(!mla, "multi-stage loading not implemented for MLA"); + static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async> - (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V); + (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V); } else { constexpr bool use_cp_async = nstages == 1; - if (ncols2 > 1 || mask_h2) { + if constexpr (ncols2 > 1 || mask_h2) { flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); } } #pragma unroll - for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) { - const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2; + for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) { + const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? 
k0_start + nbatch_K2 : DKQ/2; const int k0_diff = k0_stop - k0_start; - if (nstages <= 1) { + if constexpr (nstages <= 1) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async> (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K); - if (use_cp_async) { + if constexpr (use_cp_async) { cp_async_wait_all(); } __syncthreads(); @@ -334,7 +522,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { tile_A K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); - if (ntiles == 1) { + if constexpr (ntiles == 1) { mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); } else { #pragma unroll @@ -364,12 +552,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - if (nstages <= 1) { + if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. } } - if (use_logit_softcap) { + if constexpr (use_logit_softcap) { static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) { @@ -387,8 +575,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } float KQ_rowsum_add[cols_per_thread] = {0.0f}; - if (ntiles == 1) { - if (ncols2 > 1 || mask_h2) { + if constexpr (ntiles == 1) { + if constexpr (ncols2 > 1 || mask_h2) { #pragma unroll for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; @@ -506,7 +694,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; } - if (ntiles == 1) { + if constexpr (ntiles == 1) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { @@ -534,7 +722,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles]; tile_B_16 * B_16 = (tile_B_16 *) B; static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size"); - if (ntiles == 1) { + if constexpr (ntiles == 1) { #pragma unroll for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) { B[k] = get_transposed(get_half2(KQ_C[k])); @@ -548,7 +736,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - if (nstages > 1) { + if constexpr (nstages > 1) { // Preload K tile for next iteration: constexpr bool use_cp_async = true; cp_async_wait_all(); @@ -559,24 +747,30 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); } flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async> - (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K); + (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); } } + + // For MLA K and V have the same data. + // Therefore, iterate over V in reverse and re-use the data if possible. + static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); + constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; #pragma unroll - for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) { - const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? 
i0_start + 2*c::nbatch_V2 : DV; - const int i0_diff = i0_stop - i0_start; + for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { + const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; + const int i0_diff = i0_stop - i0_start; - if (nstages == 1) { + if (nstages <= 1 && i0_start < reusable_cutoff) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async> (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); - if (use_cp_async) { + if constexpr (use_cp_async) { cp_async_wait_all(); } __syncthreads(); } + const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; // Calculate VKQ tile: #pragma unroll @@ -587,8 +781,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0 = k00 + (threadIdx.y % np)*tile_A::J; tile_A A; - load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); - if (ntiles == 1) { + load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + if constexpr (ntiles == 1) { mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); } else { #pragma unroll @@ -600,7 +794,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - if (nstages <= 1) { + if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. } } @@ -618,7 +812,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // INT8_MMA_AVAILABLE } -template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup> +template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup> static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -654,13 +848,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); + constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); - constexpr int stride_tile_Q = DKQ/2 + 4; - constexpr int stride_tile_K = c::nbatch_K2 + 4; - constexpr int stride_tile_V = c::nbatch_V2 + 4; + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = nbatch_K2 + 4; + static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); + constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? 
stride_tile_K : stride_tile_V; extern __shared__ half2 tile_Q[]; @@ -727,12 +924,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); - if (c::Q_in_reg) { + if constexpr (c::Q_in_reg) { const int j0 = (threadIdx.y / np) * cols_per_warp; #pragma unroll for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) { - if (ntiles == 1) { + if constexpr (ntiles == 1) { load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); } else { #pragma unroll @@ -748,33 +945,33 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Preload mask and K data for first iteration when using cp_async with multiple stages: if constexpr (nstages > 1) { - static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); + static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); constexpr bool use_cp_async = true; if (ncols2 > 1 || mask_h2) { flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async> (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); } flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async> - (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K); + (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); } // Iterate over ne11 == previous tokens: for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter> + flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter> (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. constexpr bool last_iter = true; - flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter> + flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter> (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } // With multi-stage loading there is no __syncthreads at the end of the iter, // there can be a race condition on shared memory access for combining/writing back results. - if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { + if constexpr (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { __syncthreads(); } @@ -796,7 +993,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // So also write VKQ accumulators to shared memory in column-major format if np == 1. - constexpr int nbatch_combine = c::Q_in_reg ? 
DV/2 : DV/4; + constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols); constexpr int tile_stride = nbatch_combine + 4; static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); @@ -873,10 +1070,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset >= WARP_SIZE) { - continue; + if (offset < WARP_SIZE) { + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } - KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } float KQ_cms[nmeta]; // KQ combine max scale per warp. @@ -892,10 +1088,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset >= WARP_SIZE) { - continue; + if (offset < WARP_SIZE) { + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } - KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } // Write back combined meta data: @@ -921,7 +1116,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { - if (ntiles == 1) { + if constexpr (ntiles == 1) { const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data #pragma unroll for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) { @@ -1029,7 +1224,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #endif // INT8_MMA_AVAILABLE } -template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap> +template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla> __launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -1070,10 +1265,18 @@ static __global__ void flash_attn_ext_f16( #if defined(INT8_MMA_AVAILABLE) // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { + if constexpr (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { + NO_DEVICE_CODE; + return; + } +#if __CUDA_ARCH__ == CC_TURING + if constexpr (ncols1*ncols2 > 32) { NO_DEVICE_CODE; return; } +#endif __CUDA_ARCH__ == CC_TURING + + static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); typedef fattn_mma_f16_config<DKQ, DV> c; @@ -1084,9 +1287,10 @@ static __global__ void flash_attn_ext_f16( const int stride_Q1 = nb01 / sizeof(float2); const int stride_Q2 = nb02 / sizeof(float2); const int stride_K = nb11 / sizeof(half2); - const int stride_V = nb21 / sizeof(half2); const int stride_mask = nb31 / sizeof(half2); + const int stride_V = mla ? stride_K : nb21 / sizeof(half2); + const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; @@ -1109,10 +1313,11 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const float slope = ncols2 == 1 ? 
get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; @@ -1121,12 +1326,12 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup> + flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup> (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup> + flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup> (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } @@ -1147,10 +1352,11 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; @@ -1158,7 +1364,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup> + flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup> (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else @@ -1176,6 +1382,7 @@ static __global__ void flash_attn_ext_f16( #endif // defined(INT8_MMA_AVAILABLE) } + template<int D, int ncols1, int ncols2> // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( @@ -1310,7 +1517,7 @@ static __global__ void flash_attn_combine_results_new( } template <int DV, int ncols1, int ncols2> -void launch_fattn_new_mma( +static void launch_fattn_new_mma( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE ) { @@ -1495,7 +1702,7 @@ void launch_fattn_new_mma( V_data, mask ? 
((const char *) mask->data) : nullptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, - scale, max_bias, m0, m1, n_head_log2, logit_softcap, + scale, max_bias, m0, m1, logit_softcap, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, @@ -1529,17 +1736,13 @@ void launch_fattn_new_mma( template <int DKQ, int DV, int ncols1, int ncols2> -void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +static void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const int id = ggml_cuda_get_device(); const int cc = ggml_cuda_info().devices[id].cc; typedef fattn_mma_f16_config<DKQ, DV> c; - constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2; - constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2; - constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine; - const int nstages = cp_async_available(cc) ? c::nstages_target : 0; constexpr int ncols = ncols1 * ncols2; @@ -1549,15 +1752,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; + constexpr bool mla = DKQ == 576; + + const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols); + const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols); + const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols); + static_assert(DKQ % tile_B::J == 0, "bad DKQ"); static_assert(DV % tile_A::J == 0, "bad DV"); static_assert(ncols % cols_per_warp == 0, "bad ncols"); - const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2); - const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2); - const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); - const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); - const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); + const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); const size_t nbytes_shared_KV = nstages <= 1 ? 
nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; @@ -1571,7 +1780,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>; + fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>; #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1582,7 +1791,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>; + fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>; #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; diff --git a/src/llama.cpp b/src/llama.cpp index 38a2b299..b4d42c84 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -15224,7 +15224,8 @@ struct llm_build_context { cb(kv_cache_trans, "kv_cache_trans", il); } - ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0); + //ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0); + ggml_tensor * kvr = ggml_concat(ctx0, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), kv_compressed, 0); cb(kvr, "kvr", il); auto row_size = ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope); @@ -15240,7 +15241,8 @@ struct llm_build_context { if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && pp_opt) { // PP for mla=2,3 - auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0); + auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], + ggml_row_size(kv_self.kv_l[il]->type, n_embd_head_qk_rope)); auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024); int n_max_head = n_head; @@ -15254,7 +15256,7 @@ struct llm_build_context { auto n_per_head = model.layers[il].wkv_b->ne[1] / n_head; auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1, - kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)); + kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], 0); //ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)); // There is still an issue with one or more of the ops GGML_OP_REPEAT, GGML_OP_CONCAT, GGML_OP_CPY on CUDA when // the KV cache is quantized. Hence, in that case we will simply use fp16 for now. 
@@ -15273,7 +15275,8 @@ struct llm_build_context { } cb(k_rope, "k_rope", il); - auto q = ggml_concat(ctx0, q_nope, q_rope, 0); + //auto q = ggml_concat(ctx0, q_nope, q_rope, 0); + auto q = ggml_concat(ctx0, q_rope, q_nope, 0); q = ggml_permute(ctx0, q, 0, 2, 1, 3); cb(q, "q_concat", il); @@ -15307,7 +15310,8 @@ struct llm_build_context { ggml_build_forward_expand(gf, k_nope); ggml_build_forward_expand(gf, v); - auto k = ggml_concat(ctx0, k_nope, k_rope, 0); + //auto k = ggml_concat(ctx0, k_nope, k_rope, 0); + auto k = ggml_concat(ctx0, k_rope, k_nope, 0); cb(k, "k", il); ggml_build_forward_expand(gf, k); @@ -15344,13 +15348,15 @@ struct llm_build_context { struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope); cb(q_nope2, "q_nope2", il); - ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); + //ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); + ggml_tensor * q = ggml_concat(ctx0, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), q_nope2, 0); cb(q, "q", il); if (lctx.cparams.flash_attn && (lctx.cparams.mla_attn == 1 || lctx.cparams.mla_attn == 3)) { ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, - ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_self.kv_l[il]->type, n_embd_head_qk_rope)); cb(kv_cache_lora, "kv_cache_lora", il); kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); @@ -15363,7 +15369,8 @@ struct llm_build_context { if (lctx.cparams.mla_attn > 1) { ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, - ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_self.kv_l[il]->type, n_embd_head_qk_rope)); cb(kv_cache, "kv_cache_lora", il); kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora)); |
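
A note on the loop structure in flash_attn_ext_f16_load_tile above: the unrolled loops whose bodies used continue are rewritten as __device__ lambdas driven by the small ggml_cuda_unroll recursion helper, which is also why -extended-lambda is added to the NVCC flags in CMakeLists.txt. The toy program below shows the pattern in isolation, assuming CUDA 11+ and nvcc -extended-lambda; the helper is copied from the patch, while the demo kernel and its numbers are illustrative only.

    // Toy standalone example of the recursion-based unrolling used above.
    // The ggml_cuda_unroll helper matches the patch; the kernel is illustrative.
    #include <cstdio>
    #include <cuda_runtime.h>

    template <int n>
    struct ggml_cuda_unroll {
        template <typename Func, typename... Args>
        __device__ void operator()(const Func & f, Args... args) const {
            f(n - 1, args...);
            ggml_cuda_unroll<n - 1>{}(f, args...);
        }
    };

    template <>
    struct ggml_cuda_unroll<1> {
        template <typename Func, typename... Args>
        __device__ void operator()(const Func & f, Args... args) const {
            f(0, args...);
        }
    };

    // Sums the strides 32 >> i for i = 4..0, skipping strides smaller than 4.
    // The early return plays the role of `continue` in the loop it replaces;
    // calls happen in the order f(4), f(3), ..., f(0).
    static __global__ void unroll_demo(int * out) {
        int acc = 0;
        auto body = [&] __device__ (const int i) {   // __device__ lambda -> needs -extended-lambda
            const int stride = 32 >> i;
            if (stride < 4) {
                return;                              // skip this "iteration"
            }
            acc += stride;
        };
        ggml_cuda_unroll<5>{}(body);
        *out = acc;                                  // 4 + 8 + 16 + 32 = 60
    }

    int main() {
        int * d_out = nullptr;
        int   h_out = 0;
        cudaMalloc(&d_out, sizeof(int));
        unroll_demo<<<1, 1>>>(d_out);
        cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
        cudaFree(d_out);
        printf("sum = %d\n", h_out);                 // prints 60
        return 0;
    }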