summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-05-12 07:49:00 +0300
committerGitHub <noreply@github.com>2025-05-12 07:49:00 +0300
commit465569dff8b49a195450a0eb1974fd72a32fcebc (patch)
treeaf7f5b4af3738318a28ad9c9de722231c41c3d63
parent8669c3db2b98f05775292778dd05f424ee0cd250 (diff)
Faster DeepSeek FA on CUDA (#408)
* New DeepSeek FlashMLA Does not work because the RoPE portion is stored at the end in our case, while in mainline it is stored at the beginning, and the FA kernel assumes that. * Rearrange MLA K cache so it first new CUDA FA implementation * constexpr and minor changes --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/CMakeLists.txt2
-rw-r--r--ggml/src/ggml-cuda/cp-async.cuh10
-rw-r--r--ggml/src/ggml-cuda/fattn-new-mma.cu451
-rw-r--r--src/llama.cpp23
4 files changed, 356 insertions, 130 deletions
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 74ac5374..4f4337c2 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -986,7 +986,7 @@ endif()
set(CUDA_CXX_FLAGS "")
if (GGML_CUDA)
- set(CUDA_FLAGS -use_fast_math)
+ set(CUDA_FLAGS -use_fast_math -extended-lambda)
if (GGML_FATAL_WARNINGS)
list(APPEND CUDA_FLAGS -Werror all-warnings)
diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh
index ecb65999..a87dc247 100644
--- a/ggml/src/ggml-cuda/cp-async.cuh
+++ b/ggml/src/ggml-cuda/cp-async.cuh
@@ -2,6 +2,16 @@
#include "common.cuh"
+static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
+#ifdef CP_ASYNC_AVAILABLE
+ return __cvta_generic_to_shared(generic_ptr);
+#else
+ GGML_UNUSED(generic_ptr);
+ NO_DEVICE_CODE;
+ return 0;
+#endif // CP_ASYNC_AVAILABLE
+}
+
// Copies data from global to shared memory, cg == cache global.
// Both the src and dst pointers must be aligned to 16 bit.
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu
index d1484451..630baf33 100644
--- a/ggml/src/ggml-cuda/fattn-new-mma.cu
+++ b/ggml/src/ggml-cuda/fattn-new-mma.cu
@@ -1,15 +1,16 @@
-// Adapted from https://github.com/ggml-org/llama.cpp/pull/13306
+// Adapted from https://github.com/ggml-org/llama.cpp/pull/13435
//
-// Copyright (C) 2023-2024 The ggml authors
-// Copyright (C) 2024 Iwan Kawrakow
+// Copyright (C) 2025 The ggml authors
+// Copyright (C) 2025 Iwan Kawrakow
// MIT license
// SPDX-License-Identifier: MIT
//
-#include "fattn-new-mma.cuh"
+#include "common.cuh"
#include "cp-async.cuh"
#include "mma_new.cuh"
#include "fattn-common.cuh"
+#include "fattn-new-mma.cuh"
using namespace ggml_cuda_mma;
@@ -39,6 +40,8 @@ struct fattn_mma_f16_config;
// The previous MMA version is better (faster)
// I'm keeping these around commented out for now,
// and only using the 576, 512 case.
+// Perhaps the 256 head size needs a closer look
+// to see if this implementation is better.
//
//template <>
//struct fattn_mma_f16_config< 64, 64> {
@@ -46,9 +49,30 @@ struct fattn_mma_f16_config;
// static constexpr int nwarps_max = 4;
// static constexpr bool Q_in_reg = true;
// static constexpr int nstages_target = 2;
-// static constexpr int nbatch_K2 = 32;
-// static constexpr int nbatch_V2 = 32;
-// static constexpr int nbatch_combine = 32;
+//
+// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+// return 32;
+// }
+//
+// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+// return 32;
+// }
+//
+// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+// return 32;
+// }
+//
+// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+// return 32;
+// }
+//
+// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+// return 32;
+// }
+//
+// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+// return 32;
+// }
//};
//
//template <>
@@ -57,9 +81,30 @@ struct fattn_mma_f16_config;
// static constexpr int nwarps_max = 4;
// static constexpr bool Q_in_reg = true;
// static constexpr int nstages_target = 2;
-// static constexpr int nbatch_K2 = 40;
-// static constexpr int nbatch_V2 = 40;
-// static constexpr int nbatch_combine = 40;
+//
+// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+// return 40;
+// }
+//
+// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+// return 40;
+// }
+//
+// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+// return 40;
+// }
+//
+// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+// return 40;
+// }
+//
+// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+// return 40;
+// }
+//
+// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+// return 40;
+// }
//};
//
//template <>
@@ -68,9 +113,30 @@ struct fattn_mma_f16_config;
// static constexpr int nwarps_max = 4;
// static constexpr bool Q_in_reg = true;
// static constexpr int nstages_target = 2;
-// static constexpr int nbatch_K2 = 48;
-// static constexpr int nbatch_V2 = 48;
-// static constexpr int nbatch_combine = 48;
+//
+// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+// return 48;
+// }
+//
+// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+// return 48;
+// }
+//
+// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+// return 48;
+// }
+//
+// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+// return 48;
+// }
+//
+// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+// return 48;
+// }
+//
+// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+// return 48;
+// }
//};
//
//template <>
@@ -79,9 +145,30 @@ struct fattn_mma_f16_config;
// static constexpr int nwarps_max = 4;
// static constexpr bool Q_in_reg = true;
// static constexpr int nstages_target = 2;
-// static constexpr int nbatch_K2 = 56;
-// static constexpr int nbatch_V2 = 56;
-// static constexpr int nbatch_combine = 56;
+//
+// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+// return 56;
+// }
+//
+// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+// return 56;
+// }
+//
+// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+// return 56;
+// }
+//
+// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+// return 56;
+// }
+//
+// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+// return 56;
+// }
+//
+// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+// return 56;
+// }
//};
//
//template <>
@@ -90,20 +177,30 @@ struct fattn_mma_f16_config;
// static constexpr int nwarps_max = 4;
// static constexpr bool Q_in_reg = true;
// static constexpr int nstages_target = 2;
-// static constexpr int nbatch_K2 = 64;
-// static constexpr int nbatch_V2 = 64;
-// static constexpr int nbatch_combine = 64;
-//};
//
-//template <>
-//struct fattn_mma_f16_config<192, 128> {
-// static constexpr int nbatch_fa = 64;
-// static constexpr int nwarps_max = 4;
-// static constexpr bool Q_in_reg = true;
-// static constexpr int nstages_target = 2;
-// static constexpr int nbatch_K2 = 96;
-// static constexpr int nbatch_V2 = 64;
-// static constexpr int nbatch_combine = 64;
+// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+// return 64;
+// }
+//
+// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+// return 64;
+// }
+//
+// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+// return 64;
+// }
+//
+// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+// return 64;
+// }
+//
+// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+// return 64;
+// }
+//
+// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+// return 64;
+// }
//};
//
//template <>
@@ -112,9 +209,38 @@ struct fattn_mma_f16_config;
// static constexpr int nwarps_max = 4;
// static constexpr bool Q_in_reg = true;
// static constexpr int nstages_target = 2;
-// static constexpr int nbatch_K2 = 128;
-// static constexpr int nbatch_V2 = 128;
-// static constexpr int nbatch_combine = 128;
+//
+// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+// return 128;
+// }
+//
+// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+// return 128;
+// }
+//
+// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+// return 128;
+// }
+//
+// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+// return 128;
+// }
+//
+// static int get_nbatch_combine_host(const int cc, const int ncols) {
+// if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING) {
+// return ncols <= 16 ? 128 : 64;
+// }
+// return 64;
+// }
+//
+// static constexpr __device__ int get_nbatch_combine_device(int ncols) {
+//#if __CUDA_ARCH__ == CC_TURING
+// return ncols <= 16 ? 128 : 64;
+//#else
+// GGML_UNUSED(ncols);
+// return 128;
+//#endif // __CUDA_ARCH__ == CC_TURING
+// }
//};
template <>
@@ -123,9 +249,65 @@ struct fattn_mma_f16_config<576, 512> {
static constexpr int nwarps_max = 8;
static constexpr bool Q_in_reg = false;
static constexpr int nstages_target = 1;
- static constexpr int nbatch_K2 = 160;
- static constexpr int nbatch_V2 = 128;
- static constexpr int nbatch_combine = 128;
+
+ static int get_nbatch_K2_host(const int cc, const int ncols) {
+ if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING) {
+ return ncols <= 16 ? 96 : 160;
+ }
+ return ncols <= 16 ? 288 : 160;
+ }
+
+ static constexpr __device__ int get_nbatch_K2_device(int ncols) {
+#if __CUDA_ARCH__ == CC_TURING
+ return ncols <= 16 ? 96 : 160;
+#else
+ return ncols <= 16 ? 288 : 160;
+#endif // __CUDA_ARCH__ == CC_TURING
+ }
+
+ static int get_nbatch_V2_host(const int cc, const int ncols) {
+ if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING) {
+ return ncols <= 16 ? 64 : 128;
+ }
+ return ncols <= 16 ? 256 : 128;
+ }
+
+ static constexpr __device__ int get_nbatch_V2_device(int ncols) {
+#if __CUDA_ARCH__ == GML_CUDA_CC_TURING
+ return ncols <= 16 ? 64 : 128;
+#else
+ return ncols <= 16 ? 256 : 128;
+#endif // __CUDA_ARCH__ == GML_CUDA_CC_TURING
+ }
+
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+ return 128;
+ }
+
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+ return 128;
+ }
+};
+
+// ------------------------------------------------------------------------------------------------------------------
+
+// The compiler is always able to unroll loops if they contain continue expressions.
+// In such cases loop unrolling can still be achieved via recursion:
+template <int n>
+struct ggml_cuda_unroll {
+ template <typename Func, typename... Args>
+ __device__ void operator()(const Func & f, Args... args) const {
+ f(n - 1, args...);
+ ggml_cuda_unroll<n - 1>{}(f, args...);
+ }
+};
+
+template <>
+struct ggml_cuda_unroll<1> {
+ template <typename Func, typename... Args>
+ __device__ void operator()(const Func & f, Args... args) const {
+ f(0, args...);
+ }
};
template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
@@ -136,26 +318,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
if constexpr (use_cp_async) {
- const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);
-
constexpr int preload = 64;
constexpr int h2_per_chunk = 16/sizeof(half2);
-
const int chunks_per_row = D2 / h2_per_chunk;
- int k0_start = 0;
-#pragma unroll
- for (int stride_k = WARP_SIZE; stride_k > WARP_SIZE/32; stride_k >>= 1) {
+ const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
+
+ auto load = [&] __device__ (auto n) {
+ const int stride_k = WARP_SIZE >> n;
+ const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
+ const int stride_i = WARP_SIZE / stride_k;
if (k0_start == k0_stop) {
- continue;
+ return;
}
- const int stride_i = WARP_SIZE / stride_k;
#pragma unroll
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
- const int i = i0 + threadIdx.y*stride_i + threadIdx.x / stride_k;
+ const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
break;
@@ -168,18 +349,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
}
}
- k0_start = k0_stop;
- }
+ };
+ ggml_cuda_unroll<5>{}(load);
} else {
static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
-#pragma unroll
- for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+ auto load = [&] __device__ (const int n) {
+ const int stride_k = WARP_SIZE >> n;
const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
const int k0_stop = D2 - D2 % (1*stride_k);
const int stride_i = WARP_SIZE / stride_k;
- if (k0_start == k0_stop || k0_stop <= 0) {
- continue;
+ if (k0_start == k0_stop) {
+ return;
}
#pragma unroll
@@ -197,7 +378,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
}
}
- }
+ };
+ ggml_cuda_unroll<3>{}(load);
}
}
@@ -211,7 +393,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
constexpr int stride_j = nwarps * cols_per_warp;
- const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask);
+ const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
#pragma unroll
for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
@@ -245,7 +427,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
}
}
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
@@ -283,10 +465,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr int cols_per_warp = ntiles * tile_B::I;
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int ncols = ncols1 * ncols2;
+ constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
+ constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
- constexpr int stride_tile_Q = DKQ/2 + 4;
- constexpr int stride_tile_K = c::nbatch_K2 + 4;
- constexpr int stride_tile_V = c::nbatch_V2 + 4;
+ constexpr int stride_tile_Q = DKQ/2 + 4;
+ constexpr int stride_tile_K = nbatch_K2 + 4;
+
+ static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
+ constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
const int k_VKQ_0 = kb0 * c::nbatch_fa;
tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
@@ -297,29 +484,30 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
if constexpr (nstages > 1) {
- static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
+ static_assert(!mla, "multi-stage loading not implemented for MLA");
+ static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
constexpr bool use_cp_async = true;
cp_async_wait_all();
__syncthreads();
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
- (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V);
+ (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
} else {
constexpr bool use_cp_async = nstages == 1;
- if (ncols2 > 1 || mask_h2) {
+ if constexpr (ncols2 > 1 || mask_h2) {
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
}
}
#pragma unroll
- for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) {
- const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2;
+ for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
+ const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
const int k0_diff = k0_stop - k0_start;
- if (nstages <= 1) {
+ if constexpr (nstages <= 1) {
constexpr bool use_cp_async = nstages == 1;
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
(K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
- if (use_cp_async) {
+ if constexpr (use_cp_async) {
cp_async_wait_all();
}
__syncthreads();
@@ -334,7 +522,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
tile_A K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
- if (ntiles == 1) {
+ if constexpr (ntiles == 1) {
mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
} else {
#pragma unroll
@@ -364,12 +552,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
- if (nstages <= 1) {
+ if constexpr (nstages <= 1) {
__syncthreads(); // Only needed if tile_K == tile_V.
}
}
- if (use_logit_softcap) {
+ if constexpr (use_logit_softcap) {
static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) {
@@ -387,8 +575,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
float KQ_rowsum_add[cols_per_thread] = {0.0f};
- if (ntiles == 1) {
- if (ncols2 > 1 || mask_h2) {
+ if constexpr (ntiles == 1) {
+ if constexpr (ncols2 > 1 || mask_h2) {
#pragma unroll
for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
@@ -506,7 +694,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
}
- if (ntiles == 1) {
+ if constexpr (ntiles == 1) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
#pragma unroll
for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
@@ -534,7 +722,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles];
tile_B_16 * B_16 = (tile_B_16 *) B;
static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size");
- if (ntiles == 1) {
+ if constexpr (ntiles == 1) {
#pragma unroll
for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) {
B[k] = get_transposed(get_half2(KQ_C[k]));
@@ -548,7 +736,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
- if (nstages > 1) {
+ if constexpr (nstages > 1) {
// Preload K tile for next iteration:
constexpr bool use_cp_async = true;
cp_async_wait_all();
@@ -559,24 +747,30 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
}
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
- (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K);
+ (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
}
}
+
+ // For MLA K and V have the same data.
+ // Therefore, iterate over V in reverse and re-use the data if possible.
+ static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
+ constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
#pragma unroll
- for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
- const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
- const int i0_diff = i0_stop - i0_start;
+ for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
+ const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
+ const int i0_diff = i0_stop - i0_start;
- if (nstages == 1) {
+ if (nstages <= 1 && i0_start < reusable_cutoff) {
constexpr bool use_cp_async = nstages == 1;
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
- if (use_cp_async) {
+ if constexpr (use_cp_async) {
cp_async_wait_all();
}
__syncthreads();
}
+ const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
// Calculate VKQ tile:
#pragma unroll
@@ -587,8 +781,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
tile_A A;
- load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
- if (ntiles == 1) {
+ load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+ if constexpr (ntiles == 1) {
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
} else {
#pragma unroll
@@ -600,7 +794,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
- if (nstages <= 1) {
+ if constexpr (nstages <= 1) {
__syncthreads(); // Only needed if tile_K == tile_V.
}
}
@@ -618,7 +812,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#endif // INT8_MMA_AVAILABLE
}
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
@@ -654,13 +848,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr int cols_per_warp = ntiles * tile_B::I;
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
+ constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
- constexpr int stride_tile_Q = DKQ/2 + 4;
- constexpr int stride_tile_K = c::nbatch_K2 + 4;
- constexpr int stride_tile_V = c::nbatch_V2 + 4;
+ constexpr int stride_tile_Q = DKQ/2 + 4;
+ constexpr int stride_tile_K = nbatch_K2 + 4;
+ static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
+ constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
extern __shared__ half2 tile_Q[];
@@ -727,12 +924,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
__syncthreads();
- if (c::Q_in_reg) {
+ if constexpr (c::Q_in_reg) {
const int j0 = (threadIdx.y / np) * cols_per_warp;
#pragma unroll
for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) {
- if (ntiles == 1) {
+ if constexpr (ntiles == 1) {
load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
} else {
#pragma unroll
@@ -748,33 +945,33 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
// Preload mask and K data for first iteration when using cp_async with multiple stages:
if constexpr (nstages > 1) {
- static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
+ static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
constexpr bool use_cp_async = true;
if (ncols2 > 1 || mask_h2) {
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
}
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
- (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K);
+ (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
}
// Iterate over ne11 == previous tokens:
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
constexpr bool last_iter = false;
- flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+ flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
}
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
constexpr bool last_iter = true;
- flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+ flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
}
// With multi-stage loading there is no __syncthreads at the end of the iter,
// there can be a race condition on shared memory access for combining/writing back results.
- if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) {
+ if constexpr (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) {
__syncthreads();
}
@@ -796,7 +993,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
- constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4;
+ constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
constexpr int tile_stride = nbatch_combine + 4;
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
@@ -873,10 +1070,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
#pragma unroll
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
- if (offset >= WARP_SIZE) {
- continue;
+ if (offset < WARP_SIZE) {
+ KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
}
- KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
}
float KQ_cms[nmeta]; // KQ combine max scale per warp.
@@ -892,10 +1088,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
#pragma unroll
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
- if (offset >= WARP_SIZE) {
- continue;
+ if (offset < WARP_SIZE) {
+ KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
}
- KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
}
// Write back combined meta data:
@@ -921,7 +1116,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#pragma unroll
for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
- if (ntiles == 1) {
+ if constexpr (ntiles == 1) {
const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
#pragma unroll
for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) {
@@ -1029,7 +1224,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#endif // INT8_MMA_AVAILABLE
}
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
__launch_bounds__(nwarps*WARP_SIZE, 1)
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
@@ -1070,10 +1265,18 @@ static __global__ void flash_attn_ext_f16(
#if defined(INT8_MMA_AVAILABLE)
// Skip unused kernel variants for faster compilation:
- if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
+ if constexpr (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+#if __CUDA_ARCH__ == CC_TURING
+ if constexpr (ncols1*ncols2 > 32) {
NO_DEVICE_CODE;
return;
}
+#endif __CUDA_ARCH__ == CC_TURING
+
+ static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
typedef fattn_mma_f16_config<DKQ, DV> c;
@@ -1084,9 +1287,10 @@ static __global__ void flash_attn_ext_f16(
const int stride_Q1 = nb01 / sizeof(float2);
const int stride_Q2 = nb02 / sizeof(float2);
const int stride_K = nb11 / sizeof(half2);
- const int stride_V = nb21 / sizeof(half2);
const int stride_mask = nb31 / sizeof(half2);
+ const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
+
const int iter_k = ne11 / FATTN_KQ_STRIDE;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
@@ -1109,10 +1313,11 @@ static __global__ void flash_attn_ext_f16(
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
- const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
@@ -1121,12 +1326,12 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
}
@@ -1147,10 +1352,11 @@ static __global__ void flash_attn_ext_f16(
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
- const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
@@ -1158,7 +1364,7 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
@@ -1176,6 +1382,7 @@ static __global__ void flash_attn_ext_f16(
#endif // defined(INT8_MMA_AVAILABLE)
}
+
template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
@@ -1310,7 +1517,7 @@ static __global__ void flash_attn_combine_results_new(
}
template <int DV, int ncols1, int ncols2>
-void launch_fattn_new_mma(
+static void launch_fattn_new_mma(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
) {
@@ -1495,7 +1702,7 @@ void launch_fattn_new_mma(
V_data,
mask ? ((const char *) mask->data) : nullptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
- scale, max_bias, m0, m1, n_head_log2, logit_softcap,
+ scale, max_bias, m0, m1, logit_softcap, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@@ -1529,17 +1736,13 @@ void launch_fattn_new_mma(
template <int DKQ, int DV, int ncols1, int ncols2>
-void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+static void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
typedef fattn_mma_f16_config<DKQ, DV> c;
- constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
- constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
- constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
-
const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
constexpr int ncols = ncols1 * ncols2;
@@ -1549,15 +1752,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
+ constexpr bool mla = DKQ == 576;
+
+ const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
+ const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols);
+ const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
+
static_assert(DKQ % tile_B::J == 0, "bad DKQ");
static_assert(DV % tile_A::J == 0, "bad DV");
static_assert(ncols % cols_per_warp == 0, "bad ncols");
- const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2);
- const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2);
- const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
- const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
- const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
+ const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
@@ -1571,7 +1780,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1582,7 +1791,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
} else {
constexpr bool use_logit_softcap = true;
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
diff --git a/src/llama.cpp b/src/llama.cpp
index 38a2b299..b4d42c84 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -15224,7 +15224,8 @@ struct llm_build_context {
cb(kv_cache_trans, "kv_cache_trans", il);
}
- ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0);
+ //ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0);
+ ggml_tensor * kvr = ggml_concat(ctx0, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), kv_compressed, 0);
cb(kvr, "kvr", il);
auto row_size = ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
@@ -15240,7 +15241,8 @@ struct llm_build_context {
if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && pp_opt) { // PP for mla=2,3
- auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0);
+ auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1],
+ ggml_row_size(kv_self.kv_l[il]->type, n_embd_head_qk_rope));
auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024);
int n_max_head = n_head;
@@ -15254,7 +15256,7 @@ struct llm_build_context {
auto n_per_head = model.layers[il].wkv_b->ne[1] / n_head;
auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
- kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
+ kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], 0); //ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
// There is still an issue with one or more of the ops GGML_OP_REPEAT, GGML_OP_CONCAT, GGML_OP_CPY on CUDA when
// the KV cache is quantized. Hence, in that case we will simply use fp16 for now.
@@ -15273,7 +15275,8 @@ struct llm_build_context {
}
cb(k_rope, "k_rope", il);
- auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
+ //auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
+ auto q = ggml_concat(ctx0, q_rope, q_nope, 0);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
cb(q, "q_concat", il);
@@ -15307,7 +15310,8 @@ struct llm_build_context {
ggml_build_forward_expand(gf, k_nope);
ggml_build_forward_expand(gf, v);
- auto k = ggml_concat(ctx0, k_nope, k_rope, 0);
+ //auto k = ggml_concat(ctx0, k_nope, k_rope, 0);
+ auto k = ggml_concat(ctx0, k_rope, k_nope, 0);
cb(k, "k", il);
ggml_build_forward_expand(gf, k);
@@ -15344,13 +15348,15 @@ struct llm_build_context {
struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
cb(q_nope2, "q_nope2", il);
- ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
+ //ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
+ ggml_tensor * q = ggml_concat(ctx0, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), q_nope2, 0);
cb(q, "q", il);
if (lctx.cparams.flash_attn && (lctx.cparams.mla_attn == 1 || lctx.cparams.mla_attn == 3)) {
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
kv_lora_rank, n_kv,
- ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
+ ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope),
+ ggml_row_size(kv_self.kv_l[il]->type, n_embd_head_qk_rope));
cb(kv_cache_lora, "kv_cache_lora", il);
kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
@@ -15363,7 +15369,8 @@ struct llm_build_context {
if (lctx.cparams.mla_attn > 1) {
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
kv_lora_rank, n_kv,
- ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
+ ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope),
+ ggml_row_size(kv_self.kv_l[il]->type, n_embd_head_qk_rope));
cb(kv_cache, "kv_cache_lora", il);
kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora));