-rw-r--r--  ggml/src/ggml-cuda.cu                 |    5
-rw-r--r--  ggml/src/ggml-cuda/fattn-new-mma.cu   | 1705
-rw-r--r--  ggml/src/ggml-cuda/fattn-new-mma.cuh  |    3
-rw-r--r--  ggml/src/ggml-cuda/fattn.cu           |   21
-rw-r--r--  ggml/src/ggml-cuda/mmvq.cu            |  109
5 files changed, 1798 insertions, 45 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 1f62b882..ff6e064c 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3587,6 +3587,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) ||
(op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0);
}
+ if (op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512) {
+ const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
+ int gqa = op->src[0]->ne[2]/op->src[1]->ne[2];
+ return (new_mma_available(cc) && cc >= CC_AMPERE && op->src[3] && gqa%16 == 0);
+ }
if (op->src[1]->ne[0] > 256) {
return false;
}
diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu
new file mode 100644
index 00000000..796d9c7b
--- /dev/null
+++ b/ggml/src/ggml-cuda/fattn-new-mma.cu
@@ -0,0 +1,1705 @@
+// Adapted from https://github.com/ggml-org/llama.cpp/pull/13306
+//
+// Copyright (C) 2023-2024 The ggml authors
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#include "fattn-new-mma.cuh"
+#include "cp-async.cuh"
+#include "mma_new.cuh"
+#include "fattn-common.cuh"
+
+using namespace ggml_cuda_mma;
+
+typedef tile<16, 8, half2> tile_A;
+typedef tile< 8, 8, half2> tile_B;
+typedef tile<16, 8, half2> tile_B_16;
+typedef tile<16, 8, float> tile_C_KQ;
+typedef tile<16, 16, float> tile_C_KQ_16;
+typedef tile<16, 4, half2> tile_C_VKQ;
+typedef tile<16, 8, half2> tile_C_VKQ_16;
+
+// Config options for specific head sizes.
+// Should not affect results, only speed/register pressure/shared memory use.
+//
+// nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
+// nwarps_max: maximum number of warps per CUDA block; up to 8 warps in total can run per SM (given enough shared memory).
+// Q_in_reg: whether the Q values should be kept permanently in registers.
+// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading.
+// nbatch_K2: number of K half2 values in direction of DKQ to load in parallel.
+// nbatch_V2: number of V half2 values in direction of DV to load in parallel.
+// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel.
+
+template <int DKQ, int DV>
+struct fattn_mma_f16_config;
+
+//
+// The previous MMA version is better (faster), so I'm keeping these configs
+// around commented out for now and only using the 576, 512 case.
+//
+//template <>
+//struct fattn_mma_f16_config< 64, 64> {
+// static constexpr int nbatch_fa = 64;
+// static constexpr int nwarps_max = 4;
+// static constexpr bool Q_in_reg = true;
+// static constexpr int nstages_target = 2;
+// static constexpr int nbatch_K2 = 32;
+// static constexpr int nbatch_V2 = 32;
+// static constexpr int nbatch_combine = 32;
+//};
+//
+//template <>
+//struct fattn_mma_f16_config< 80, 80> {
+// static constexpr int nbatch_fa = 64;
+// static constexpr int nwarps_max = 4;
+// static constexpr bool Q_in_reg = true;
+// static constexpr int nstages_target = 2;
+// static constexpr int nbatch_K2 = 40;
+// static constexpr int nbatch_V2 = 40;
+// static constexpr int nbatch_combine = 40;
+//};
+//
+//template <>
+//struct fattn_mma_f16_config< 96, 96> {
+// static constexpr int nbatch_fa = 64;
+// static constexpr int nwarps_max = 4;
+// static constexpr bool Q_in_reg = true;
+// static constexpr int nstages_target = 2;
+// static constexpr int nbatch_K2 = 48;
+// static constexpr int nbatch_V2 = 48;
+// static constexpr int nbatch_combine = 48;
+//};
+//
+//template <>
+//struct fattn_mma_f16_config<112, 112> {
+// static constexpr int nbatch_fa = 64;
+// static constexpr int nwarps_max = 4;
+// static constexpr bool Q_in_reg = true;
+// static constexpr int nstages_target = 2;
+// static constexpr int nbatch_K2 = 56;
+// static constexpr int nbatch_V2 = 56;
+// static constexpr int nbatch_combine = 56;
+//};
+//
+//template <>
+//struct fattn_mma_f16_config<128, 128> {
+// static constexpr int nbatch_fa = 64;
+// static constexpr int nwarps_max = 4;
+// static constexpr bool Q_in_reg = true;
+// static constexpr int nstages_target = 2;
+// static constexpr int nbatch_K2 = 64;
+// static constexpr int nbatch_V2 = 64;
+// static constexpr int nbatch_combine = 64;
+//};
+//
+//template <>
+//struct fattn_mma_f16_config<192, 128> {
+// static constexpr int nbatch_fa = 64;
+// static constexpr int nwarps_max = 4;
+// static constexpr bool Q_in_reg = true;
+// static constexpr int nstages_target = 2;
+// static constexpr int nbatch_K2 = 96;
+// static constexpr int nbatch_V2 = 64;
+// static constexpr int nbatch_combine = 64;
+//};
+//
+//template <>
+//struct fattn_mma_f16_config<256, 256> {
+// static constexpr int nbatch_fa = 32;
+// static constexpr int nwarps_max = 4;
+// static constexpr bool Q_in_reg = true;
+// static constexpr int nstages_target = 2;
+// static constexpr int nbatch_K2 = 128;
+// static constexpr int nbatch_V2 = 128;
+// static constexpr int nbatch_combine = 128;
+//};
+
+template <>
+struct fattn_mma_f16_config<576, 512> {
+ static constexpr int nbatch_fa = 32;
+ static constexpr int nwarps_max = 8;
+ static constexpr bool Q_in_reg = false;
+ static constexpr int nstages_target = 1;
+ static constexpr int nbatch_K2 = 160;
+ static constexpr int nbatch_V2 = 128;
+ static constexpr int nbatch_combine = 128;
+};
+
+template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
+ const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
+
+ // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
+ // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
+
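+    // For example (a sketch, assuming the 576, 512 config with cp.async): a K
+    // batch has D2 = nbatch_K2 = 160 half2 values, i.e. chunks_per_row = 40
+    // 16-byte chunks. The first pass uses stride_k = 32 lanes per row and
+    // covers chunks [0, 32); stride_k = 16 is skipped (k0_stop == k0_start == 32)
+    // and stride_k = 8 covers the remaining chunks [32, 40), 4 rows per warp.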
+ if constexpr (use_cp_async) {
+ const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);
+
+ constexpr int preload = 64;
+ constexpr int h2_per_chunk = 16/sizeof(half2);
+
+ const int chunks_per_row = D2 / h2_per_chunk;
+
+ int k0_start = 0;
+#pragma unroll
+ for (int stride_k = WARP_SIZE; stride_k > WARP_SIZE/32; stride_k >>= 1) {
+ const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
+
+ if (k0_start == k0_stop) {
+ continue;
+ }
+
+ const int stride_i = WARP_SIZE / stride_k;
+#pragma unroll
+ for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+ const int i = i0 + threadIdx.y*stride_i + threadIdx.x / stride_k;
+
+ if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+ break;
+ }
+
+#pragma unroll
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+ cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
+ }
+ }
+ k0_start = k0_stop;
+ }
+ } else {
+ static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
+#pragma unroll
+ for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+ const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
+ const int k0_stop = D2 - D2 % (1*stride_k);
+ const int stride_i = WARP_SIZE / stride_k;
+
+ if (k0_start == k0_stop || k0_stop <= 0) {
+ continue;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+ const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+ if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+ break;
+ }
+
+#pragma unroll
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+ tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
+ }
+ }
+ }
+ }
+}
+
+template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
+ const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
+    static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad nbatch_fa");
+
+ if constexpr (use_cp_async) {
+ constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
+ constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
+ constexpr int stride_j = nwarps * cols_per_warp;
+
+ const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask);
+
+#pragma unroll
+ for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
+ const int j = j0 + threadIdx.y*cols_per_warp +
+ (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp));
+
+ if (j0 + stride_j > ncols1 && j >= ncols1) {
+ break;
+ }
+
+ const int i = 4 * (threadIdx.x % (nbatch_fa/8));
+
+ cp_async_cg_16<preload>(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
+ }
+ return;
+ }
+
+ constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
+ constexpr int stride_j = nwarps * cols_per_warp;
+#pragma unroll
+ for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
+ const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp));
+
+ if (j0 + stride_j > ncols1 && j >= ncols1) {
+ break;
+ }
+
+ const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp);
+
+ tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i];
+ }
+}
+
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
+static __device__ __forceinline__ void flash_attn_ext_f16_iter(
+ const float2 * const __restrict__ Q_f2,
+ const half2 * const __restrict__ K_h2,
+ const half2 * const __restrict__ V_h2,
+ const half2 * const __restrict__ mask_h2,
+ float2 * const __restrict__ dstk,
+ float2 * const __restrict__ dstk_fixup,
+ const float scale,
+ const float slope,
+ const float logit_softcap,
+ const int ne01,
+ const int ne02,
+ const int stride_K,
+ const int stride_V,
+ const int stride_mask,
+ const int jt,
+ half2 * const __restrict__ tile_Q,
+ half2 * const __restrict__ tile_K,
+ half2 * const __restrict__ tile_V,
+ half2 * const __restrict__ tile_mask,
+ const tile_B * const __restrict__ Q_B,
+ tile_C_VKQ * const __restrict__ VKQ_C,
+ float * const __restrict__ KQ_max,
+ float * const __restrict__ KQ_rowsum,
+ const int kb0) {
+#ifdef INT8_MMA_AVAILABLE
+ typedef fattn_mma_f16_config<DKQ, DV> c;
+
+#ifdef CP_ASYNC_AVAILABLE
+ constexpr int nstages = c::nstages_target;
+#else
+ constexpr int nstages = 0;
+#endif // CP_ASYNC_AVAILABLE
+
+ constexpr int cols_per_warp = ntiles * tile_B::I;
+ constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
+ constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+
+ constexpr int stride_tile_Q = DKQ/2 + 4;
+ constexpr int stride_tile_K = c::nbatch_K2 + 4;
+ constexpr int stride_tile_V = c::nbatch_V2 + 4;
+
+ const int k_VKQ_0 = kb0 * c::nbatch_fa;
+ tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
+
+ // Use wide variants of tiles if ntiles >= 2.
+ tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
+ tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
+ tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
+
+ if constexpr (nstages > 1) {
+        static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage loading");
+ constexpr bool use_cp_async = true;
+ cp_async_wait_all();
+ __syncthreads();
+ flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
+ (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V);
+ } else {
+ constexpr bool use_cp_async = nstages == 1;
+ if (ncols2 > 1 || mask_h2) {
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
+ }
+ }
+
+#pragma unroll
+ for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) {
+ const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2;
+ const int k0_diff = k0_stop - k0_start;
+
+ if (nstages <= 1) {
+ constexpr bool use_cp_async = nstages == 1;
+ flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
+ (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
+ if (use_cp_async) {
+ cp_async_wait_all();
+ }
+ __syncthreads();
+ }
+
+ // Calculate tile of KQ:
+ if constexpr (c::Q_in_reg) {
+#pragma unroll
+ for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
+ const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
+#pragma unroll
+ for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
+ tile_A K_A;
+ load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
+ if (ntiles == 1) {
+ mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
+ } else {
+#pragma unroll
+ for (int t = 0; t < ntiles/2; ++t) {
+ // Wide version of KQ_C is column-major => swap A and B.
+ mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
+ }
+ }
+ }
+ }
+ } else {
+ static_assert(ntiles == 2, "ntiles != 2 not implemented");
+#pragma unroll
+ for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
+ load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
+
+#pragma unroll
+ for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
+ const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
+
+ tile_A K_A;
+ load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
+
+ // Wide version of KQ_C is column-major => swap A and B.
+ mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A);
+ }
+ }
+ }
+
+ if (nstages <= 1) {
+ __syncthreads(); // Only needed if tile_K == tile_V.
+ }
+ }
+
+ if (use_logit_softcap) {
+ static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+ for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) {
+#pragma unroll
+ for (int l = 0; l < tile_C_KQ::ne; ++l) {
+ KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
+ }
+ }
+ }
+
+ float KQ_max_new[cols_per_thread];
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+ KQ_max_new[col] = KQ_max[col];
+ }
+ float KQ_rowsum_add[cols_per_thread] = {0.0f};
+
+ if (ntiles == 1) {
+ if (ncols2 > 1 || mask_h2) {
+#pragma unroll
+ for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
+ const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
+#pragma unroll
+ for (int l = 0; l < tile_C_KQ::ne; ++l) {
+ const int i = i0 + tile_C_KQ::get_i(l);
+ const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
+
+ KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
+ __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]);
+ }
+ }
+ }
+
+ // Calculate softmax for each KQ column using the current max. value.
+ // The divisor is stored in KQ_rowsum and will be applied at the end.
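+        // Per KQ column this is the standard online softmax recurrence, here as
+        // a sketch with s_k the KQ values of the current tile and m/r the
+        // running maximum/rowsum:
+        //   m_new = max(m, max_k s_k)
+        //   p_k   = exp(s_k - m_new)
+        //   r_new = exp(m - m_new)*r + sum_k p_k
+        // The VKQ accumulators are rescaled by exp(m - m_new) further below.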
+ static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+ for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
+#pragma unroll
+ for (int l = 0; l < tile_C_KQ::ne; ++l) {
+ KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
+ }
+ }
+
+    // Values per KQ column are spread across 8 threads, so a full warp reduce is not needed:
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+ for (int offset = 16; offset >= 4; offset >>= 1) {
+ KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+ }
+ }
+
+ static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+ for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
+#pragma unroll
+ for (int l = 0; l < tile_C_KQ::ne; ++l) {
+ KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
+
+ KQ_rowsum_add[l % 2] += KQ_C[k].x[l];
+ }
+ }
+ } else { // ntiles > 1
+ if (ncols2 > 1 || mask_h2) {
+#pragma unroll
+ for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) {
+ const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
+#pragma unroll
+ for (int t = 0; t < ntiles/2; ++t) {
+#pragma unroll
+ for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) {
+ const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
+ const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
+
+ const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]);
+ const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
+ KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
+ KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
+ }
+ }
+ }
+ }
+
+ // Calculate softmax for each KQ column using the current max. value.
+ // The divisor is stored in KQ_rowsum and will be applied at the end.
+ static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+ for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
+#pragma unroll
+ for (int t = 0; t < ntiles/2; ++t) {
+#pragma unroll
+ for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
+ const int KQ_index = 2*t + (l/2) % 2;
+ KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
+ }
+ }
+ }
+
+    // Values per KQ column are spread across 4 threads, so a full warp reduce is not needed:
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+ for (int offset = 2; offset >= 1; offset >>= 1) {
+ KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+ }
+ }
+
+ static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size");
+#pragma unroll
+ for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
+#pragma unroll
+ for (int t = 0; t < ntiles/2; ++t) {
+#pragma unroll
+ for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
+ const int KQ_index = 2*t + (l/2) % 2;
+
+ KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]);
+
+ KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l];
+ }
+ }
+ }
+ }
+
+ {
+ float KQ_max_scale[cols_per_thread];
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+ KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
+ KQ_max[col] = KQ_max_new[col];
+
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+ KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
+ }
+
+ if (ntiles == 1) {
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+#pragma unroll
+ for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+#pragma unroll
+ for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
+ }
+ }
+ } else {
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+ for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+#pragma unroll
+ for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
+ VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+ }
+ }
+ }
+ }
+ }
+
+ // Convert KQ C tiles into B tiles for VKQ calculation:
+ tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles];
+ tile_B_16 * B_16 = (tile_B_16 *) B;
+ static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size");
+ if (ntiles == 1) {
+#pragma unroll
+ for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) {
+ B[k] = get_transposed(get_half2(KQ_C[k]));
+ }
+    } else {
+#pragma unroll
+        for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) {
+#pragma unroll
+ for (int t = 0; t < ntiles/2; ++t) {
+ B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
+ }
+ }
+ }
+
+ if (nstages > 1) {
+ // Preload K tile for next iteration:
+ constexpr bool use_cp_async = true;
+ cp_async_wait_all();
+ __syncthreads();
+ if (!last_iter) {
+ if (ncols2 > 1 || mask_h2) {
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
+ (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
+ }
+ flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
+ (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K);
+ }
+ }
+
+#pragma unroll
+ for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
+ const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
+ const int i0_diff = i0_stop - i0_start;
+
+ if (nstages == 1) {
+ constexpr bool use_cp_async = nstages == 1;
+ flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
+ (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
+ if (use_cp_async) {
+ cp_async_wait_all();
+ }
+ __syncthreads();
+ }
+
+ // Calculate VKQ tile:
+#pragma unroll
+ for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) {
+ static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size");
+#pragma unroll
+ for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) {
+ const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
+
+ tile_A A;
+ load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+ if (ntiles == 1) {
+ mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
+ } else {
+#pragma unroll
+ for (int t = 0; t < ntiles/2; ++t) {
+ // Wide version of VKQ_C is column-major => swap A and B.
+ mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
+ }
+ }
+ }
+ }
+
+ if (nstages <= 1) {
+ __syncthreads(); // Only needed if tile_K == tile_V.
+ }
+ }
+#else
+ GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
+ GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
+ GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
+ GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
+ GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+    GGML_UNUSED(tile_Q);
+ GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
+ GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
+ GGML_UNUSED(kb0);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+}
+
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
+static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
+ const float2 * const __restrict__ Q_f2,
+ const half2 * const __restrict__ K_h2,
+ const half2 * const __restrict__ V_h2,
+ const half2 * const __restrict__ mask_h2,
+ float2 * const __restrict__ dstk,
+ float2 * const __restrict__ dstk_fixup,
+ const float scale,
+ const float slope,
+ const float logit_softcap,
+ const int ne01,
+ const int ne02,
+ const int stride_Q1,
+ const int stride_Q2,
+ const int stride_K,
+ const int stride_V,
+ const int stride_mask,
+ const int jt,
+ const int kb0_start,
+ const int kb0_stop) {
+#ifdef INT8_MMA_AVAILABLE
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+ typedef fattn_mma_f16_config<DKQ, DV> c;
+
+#ifdef CP_ASYNC_AVAILABLE
+ constexpr int nstages = c::nstages_target;
+#else
+ constexpr int nstages = 0;
+#endif // CP_ASYNC_AVAILABLE
+
+ constexpr int ncols = ncols1 * ncols2;
+ constexpr int cols_per_warp = ntiles * tile_B::I;
+ constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
+ constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+
+ static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
+
+ constexpr int stride_tile_Q = DKQ/2 + 4;
+ constexpr int stride_tile_K = c::nbatch_K2 + 4;
+ constexpr int stride_tile_V = c::nbatch_V2 + 4;
+
+ constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
+
+ extern __shared__ half2 tile_Q[];
+ half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q;
+ half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K;
+ half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max;
+
+ tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles];
+ tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles];
+
+ tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
+ tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
+
+ float KQ_rowsum[cols_per_thread] = {0.0f};
+ float KQ_max[cols_per_thread];
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+ KQ_max[col] = -FLT_MAX/2.0f;
+ }
+
+ // Load Q data into tile_Q, either temporarily or permanently.
+ // Q in registers is faster, but register pressure is the biggest bottleneck.
+ // The loading is done with decreasing granularity for D for better memory bandwidth.
+ const half2 scale_h2 = make_half2(scale, scale);
+#pragma unroll
+ for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+ const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
+ const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k);
+ const int stride_jc = WARP_SIZE / stride_k;
+
+ if (k0_start == k0_stop) {
+ continue;
+ }
+
+#pragma unroll
+ for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
+ const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+ if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
+ break;
+ }
+
+ const int j = jc / ncols2;
+ const int c = jc % ncols2;
+
+ if (jt*ncols1 + j < ne01) {
+#pragma unroll
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+ const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
+ tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
+ }
+ } else {
+#pragma unroll
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+ tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
+ }
+ }
+ }
+ }
+
+ __syncthreads();
+
+ if (c::Q_in_reg) {
+ const int j0 = (threadIdx.y / np) * cols_per_warp;
+
+#pragma unroll
+ for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) {
+ if (ntiles == 1) {
+ load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
+ } else {
+#pragma unroll
+ for (int t = 0; t < ntiles/2; ++t) {
+ load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
+ tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q);
+ }
+ }
+ }
+ }
+
+ __syncthreads();
+
+ // Preload mask and K data for first iteration when using cp_async with multiple stages:
+ if constexpr (nstages > 1) {
+ static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
+ constexpr bool use_cp_async = true;
+ if (ncols2 > 1 || mask_h2) {
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
+ (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
+ }
+ flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
+ (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K);
+ }
+
+ // Iterate over ne11 == previous tokens:
+ for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
+ constexpr bool last_iter = false;
+ flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+ (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
+ ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+ }
+ { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
+ constexpr bool last_iter = true;
+ flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+ (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
+ ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+ }
+
+ // With multi-stage loading there is no __syncthreads at the end of the iter,
+    // so there can be a race condition on shared memory access for combining/writing back results.
+ if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) {
+ __syncthreads();
+ }
+
+ // Finally, sum up partial KQ rowsums.
+    // The partial sums are spread across 8/4 threads each, so a full reduce is not needed.
+ {
+ constexpr int offset_first = ntiles == 1 ? 16 : 2;
+ constexpr int offset_last = ntiles == 1 ? 4 : 1;
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+ for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
+ KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
+ }
+ }
+ }
+
+ // Combine VKQ accumulator values if np > 1.
+    // It's also faster to do small writes to shared memory followed by one large write to VRAM than to do many small writes to VRAM.
+    // So the VKQ accumulators are also written to shared memory in column-major format if np == 1.
+
+ constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4;
+ constexpr int tile_stride = nbatch_combine + 4;
+ static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
+
+ if constexpr (ntiles == 1) {
+ const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
+ const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
+ const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum
+
+ if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
+ // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
+ ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
+ }
+
+ __syncthreads();
+
+ if (np == 1) {
+ // No combination is needed, the meta data can be directly written from registers to VRAM.
+ if (needs_fixup && threadIdx.x < tile_B::I) {
+ float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+ dstk_fixup_meta[jc_cwm] = KQ_cmr;
+ }
+ if (is_fixup && threadIdx.x < tile_B::I) {
+ float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+ dstk_fixup_meta[jc_cwm] = KQ_cmr;
+ }
+ }
+ } else {
+ static_assert(ntiles == 2 || ntiles == 4, "bad ntiles");
+ const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta
+ + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0)
+ + tile_C_VKQ_16::get_i(threadIdx.x % 4);
+ const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum
+
+ if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
+ // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
+ ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
+ }
+
+ __syncthreads();
+
+ if (np == 1) {
+ // No combination is needed, the meta data can be directly written from registers to VRAM.
+ if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
+ float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+ dstk_fixup_meta[jc_cwm] = KQ_cmr;
+ }
+ if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
+ float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+ dstk_fixup_meta[jc_cwm] = KQ_cmr;
+ }
+ }
+ }
+
+ static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles");
+ if (np > 1 && threadIdx.y % np == 0) {
+ // Combine the meta data for parallel warps via shared memory.
+ // Warps with threadIdx.y % np != 0 must NOT return early.
+ // All threads must return simultaneously to avoid race conditions with work on the next tile.
+
+ constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
+
+ const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
+ float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
+ float2 meta[nmeta];
+#pragma unroll
+ for (int imeta = 0; imeta < nmeta; ++imeta) {
+ meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
+ }
+
+ float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
+#pragma unroll
+ for (int imeta = 1; imeta < nmeta; ++imeta) {
+ KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x);
+ }
+#pragma unroll
+ for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
+ if (offset >= WARP_SIZE) {
+ continue;
+ }
+ KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
+ }
+
+ float KQ_cms[nmeta]; // KQ combine max scale per warp.
+#pragma unroll
+ for (int imeta = 0; imeta < nmeta; ++imeta) {
+ KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn);
+ }
+
+ float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps.
+#pragma unroll
+ for (int imeta = 1; imeta < nmeta; ++imeta) {
+ KQ_crs += KQ_cms[imeta]*meta[imeta].y;
+ }
+#pragma unroll
+ for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
+ if (offset >= WARP_SIZE) {
+ continue;
+ }
+ KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
+ }
+
+ // Write back combined meta data:
+#pragma unroll
+ for (int imeta = 0; imeta < nmeta; ++imeta) {
+ if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
+ // Combined KQ max scale + rowsum.
+ meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
+ }
+ }
+
+ // Combined KQ max + rowsum.
+ static_assert(cols_per_warp <= WARP_SIZE);
+ if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+ float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+ dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+ }
+ if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+ float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+ dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+ }
+ }
+
+#pragma unroll
+ for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
+ if (ntiles == 1) {
+ const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
+#pragma unroll
+ for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) {
+ const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
+
+#pragma unroll
+ for (int l = 0; l < tile_B::ne; ++l) {
+ const int k = k0 + tile_B::get_j(l);
+
+ tile_Q[jc_cwd*tile_stride + k] = B.x[l];
+ }
+ }
+ } else {
+#pragma unroll
+ for (int t = 0; t < ntiles/2; ++t) {
+ const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
+#pragma unroll
+ for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) {
+#pragma unroll
+ for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
+ const int j = j0 + tile_C_VKQ_16::get_i(l);
+ const int k = k0 + tile_C_VKQ_16::get_j(l);
+
+ tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
+ }
+ }
+ }
+ }
+
+ __syncthreads();
+
+ if (np == 1 || threadIdx.y % np == 0) {
+ // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
+ // The values after that are for the partial results of the individual blocks.
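+            // A sketch of the dstk_fixup layout in float2 units:
+            //   [0,                 2*gridDim.x*ncols): per-block meta (KQ max, KQ rowsum)
+            //                                           written above for needs_fixup/is_fixup,
+            //   [2*gridDim.x*ncols, ...):               per-block partial VKQ results,
+            //                                           ncols*(DV/2) float2 values per block.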
+ float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
+
+#pragma unroll
+ for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+ const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
+ const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k);
+ const int stride_jc = WARP_SIZE / stride_k;
+
+ if (k0_start == k0_stop) {
+ continue;
+ }
+
+#pragma unroll
+ for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
+ const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+ if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
+ break;
+ }
+
+ const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
+
+ const int j_dst = jc_dst / ncols2;
+ const int c_dst = jc_dst % ncols2;
+
+ if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
+ continue;
+ }
+
+ const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
+#pragma unroll
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+ float2 dstk_val = make_float2(0.0f, 0.0f);
+#pragma unroll
+ for (int ip = 0; ip < np; ++ip) {
+ const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0];
+ const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]);
+ dstk_val.x += dstk_val_add.x*KQ_crs;
+ dstk_val.y += dstk_val_add.y*KQ_crs;
+ }
+
+ if (!needs_fixup && !is_fixup) {
+ const float KQ_rowsum_j = meta_j[1];
+ dstk_val.x /= KQ_rowsum_j;
+ dstk_val.y /= KQ_rowsum_j;
+ }
+
+ if (is_fixup) {
+ dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val;
+ } else {
+ dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val;
+ }
+ }
+ }
+ }
+ }
+ if (np > 1) {
+ __syncthreads();
+ }
+ }
+#else
+ GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
+ GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
+ GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
+ GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
+ GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
+ GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+}
+
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+static __global__ void flash_attn_ext_f16(
+ const char * __restrict__ Q,
+ const char * __restrict__ K,
+ const char * __restrict__ V,
+ const char * __restrict__ mask,
+ float * __restrict__ dst,
+ float2 * __restrict__ dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const float logit_softcap,
+ const uint32_t n_head_log2,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int ne03,
+ const int ne10,
+ const int ne11,
+ const int ne12,
+ const int ne13,
+ const int ne31,
+ const int nb31,
+ const int nb01,
+ const int nb02,
+ const int nb03,
+ const int nb11,
+ const int nb12,
+ const int nb13,
+ const int nb21,
+ const int nb22,
+ const int nb23,
+ const int ne0,
+ const int ne1,
+ const int ne2,
+ const int ne3) {
+#if defined(INT8_MMA_AVAILABLE)
+
+ // Skip unused kernel variants for faster compilation:
+ if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
+ typedef fattn_mma_f16_config<DKQ, DV> c;
+
+ static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config<DKQ, DV>::nbatch_fa == 0, "bad nbatch_fa");
+
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+
+ const int stride_Q1 = nb01 / sizeof(float2);
+ const int stride_Q2 = nb02 / sizeof(float2);
+ const int stride_K = nb11 / sizeof(half2);
+ const int stride_V = nb21 / sizeof(half2);
+ const int stride_mask = nb31 / sizeof(half2);
+
+ const int iter_k = ne11 / FATTN_KQ_STRIDE;
+ const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+
+ constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
+
+ // kbc == k block continuous, current index in continuous ijk space.
+ int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+ const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+
+    // If the seams of 2 CUDA blocks fall within an output tile, their results need to be combined.
+ // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
+ // In the most general case >2 seams can fall into the same tile.
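+    //
+    // A small worked example with hypothetical sizes: iter_k = 4, iter_j = 1 and
+    // ne02/ncols2 = 2 give 8 kbc indices; with gridDim.x = 3, block 0 gets
+    // kbc [0, 2), block 1 gets [2, 5) and block 2 gets [5, 8). Block 0 then starts
+    // tile 0 (kb0 0..1, is_fixup), block 1 finishes it (kb0 2..3, needs_fixup)
+    // and starts tile 1 (kb0 0, is_fixup), and block 2 finishes tile 1 (kb0 1..3).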
+
+ // kb0 == k start index when in the output tile.
+ int kb0_start = kbc % iter_k;
+ int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
+ while (kbc < kbc_stop && kb0_stop == iter_k) {
+ const int channel = kbc / (iter_k*iter_j);
+ const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
+ const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
+ const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+ const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+ float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
+
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+
+ const int kb0_start_kernel = kb0_start * kb_niter;
+ const int kb0_stop_kernel = kb0_stop * kb_niter;
+
+ constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
+ if (kb0_start == 0) {
+ constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+ (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+ ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+ } else {
+            constexpr bool needs_fixup = true; // CUDA block is finishing a tile that a previous block started.
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+ (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+ ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+ }
+
+ kbc += iter_k;
+ kbc -= kbc % iter_k;
+
+ kb0_start = 0;
+ kb0_stop = min(iter_k, kbc_stop - kbc);
+ }
+
+ if (kbc >= kbc_stop) {
+ return;
+ }
+
+ const int channel = kbc / (iter_k*iter_j);
+ const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
+ const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
+ const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
+ const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+ float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
+
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+
+ const int kb0_start_kernel = kb0_start * kb_niter;
+ const int kb0_stop_kernel = kb0_stop * kb_niter;
+
+ constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
+ constexpr bool needs_fixup = false;
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+ (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+ ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+#else
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+ GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+ GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
+ GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
+ GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+ GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+ GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
+ GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+ GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+ NO_DEVICE_CODE;
+#endif // defined(INT8_MMA_AVAILABLE)
+}
+
+template<int D, int ncols1, int ncols2> // D == head size
+__launch_bounds__(D, 1)
+static __global__ void flash_attn_stream_k_fixup(
+ float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+ constexpr int ncols = ncols1*ncols2;
+
+ const int bidx0 = blockIdx.x;
+ const int j = blockIdx.y;
+ const int c = blockIdx.z;
+ const int jc = j*ncols2 + c;
+ const int tid = threadIdx.x;
+
+ const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+ const int iter_k = ne11 / FATTN_KQ_STRIDE;
+ const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+
+ const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+ const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+
+ const bool did_not_have_any_data = kbc0 == kbc0_stop;
+ const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
+ const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+ if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+ return;
+ }
+
+ const int channel = kbc0 / (iter_k*iter_j);
+ const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
+
+ if (jt*ncols1 + j >= ne01) {
+ return;
+ }
+
+ dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+
+ // Load the partial result that needs a fixup:
+ float dst_val = 0.0f;
+ float max_val = 0.0f;
+ float rowsum = 0.0f;
+ {
+ dst_val = *dst;
+
+ const float2 tmp = dst_fixup[bidx0*ncols + jc];
+ max_val = tmp.x;
+ rowsum = tmp.y;
+ }
+
+ // Iterate over previous blocks and compute the combined results.
+ // All CUDA blocks that get here must have a previous block that needs a fixup.
+ int bidx = bidx0 - 1;
+ int kbc_stop = kbc0;
+    while (true) {
+ const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+ if (kbc == kbc_stop) { // Did not have any data.
+ bidx--;
+ kbc_stop = kbc;
+ continue;
+ }
+
+ const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
+
+ const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
+
+ // Scale the current and new value accumulators depending on the max. values.
+ const float max_val_new = fmaxf(max_val, tmp.x);
+
+ const float diff_val = max_val - max_val_new;
+ const float diff_add = tmp.x - max_val_new;
+
+ const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+ const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+
+ dst_val = scale_val*dst_val + scale_add*dst_add;
+ rowsum = scale_val*rowsum + scale_add*tmp.y;
+
+ max_val = max_val_new;
+
+ // If this block started in a previous tile we are done and don't need to combine additional partial results.
+ if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+ break;
+ }
+ bidx--;
+ kbc_stop = kbc;
+ }
+
+ // Write back final result:
+ *dst = dst_val / rowsum;
+}
+
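+// flash_attn_combine_results_new merges the partial results of parallel_blocks
+// KV slices. With m_l/r_l the per-slice KQ maximum/rowsum and m = max_l m_l it
+// computes, as a sketch of the math implemented below:
+//   dst = sum_l exp(m_l - m)*VKQ_parts_l / sum_l exp(m_l - m)*r_l
+//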
+template<int D> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_combine_results_new(
+ const float * __restrict__ VKQ_parts,
+ const float2 * __restrict__ VKQ_meta,
+ float * __restrict__ dst,
+ const int parallel_blocks) {
+ VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
+ VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
+ dst += D * gridDim.z*blockIdx.x;
+
+ const int tid = threadIdx.x;
+ __builtin_assume(tid < D);
+
+ extern __shared__ float2 meta[];
+ if (tid < 2*parallel_blocks) {
+        ((float *) meta)[tid] = ((const float *) VKQ_meta)[blockIdx.z*(2*parallel_blocks) + tid];
+ }
+
+ __syncthreads();
+
+ float kqmax = meta[0].x;
+ for (int l = 1; l < parallel_blocks; ++l) {
+ kqmax = max(kqmax, meta[l].x);
+ }
+
+ float VKQ_numerator = 0.0f;
+ float VKQ_denominator = 0.0f;
+ for (int l = 0; l < parallel_blocks; ++l) {
+ const float diff = meta[l].x - kqmax;
+ float KQ_max_scale = expf(diff);
+ const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
+ *((uint32_t *) &KQ_max_scale) &= ftz_mask;
+
+ VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
+ VKQ_denominator += KQ_max_scale * meta[l].y;
+ }
+
+ dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
+}
+
+template <int DV, int ncols1, int ncols2>
+void launch_fattn_new_mma(
+ ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
+ const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
+) {
+ constexpr int ncols = ncols1 * ncols2;
+
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+
+ const ggml_tensor * mask = dst->src[3];
+
+ ggml_tensor * KQV = dst;
+
+ GGML_ASSERT(Q->type == GGML_TYPE_F32);
+ GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+ GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+ "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+
+ GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
+ GGML_ASSERT(Q->ne[3] == 1);
+
+ ggml_cuda_pool & pool = ctx.pool();
+ cudaStream_t main_stream = ctx.stream();
+ const int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+ const int nsm = ggml_cuda_info().devices[id].nsm;
+
+ ggml_cuda_pool_alloc<half> K_f16(pool);
+ ggml_cuda_pool_alloc<half> V_f16(pool);
+ ggml_cuda_pool_alloc<float> dst_tmp(pool);
+ ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+ const char * K_data = (const char *) K->data;
+ size_t nb11 = K->nb[1];
+ size_t nb12 = K->nb[2];
+ size_t nb13 = K->nb[3];
+
+ const char * V_data = (const char *) V->data;
+ size_t nb21 = V->nb[1];
+ size_t nb22 = V->nb[2];
+ size_t nb23 = V->nb[3];
+
+ if (need_f16_K && K->type != GGML_TYPE_F16) {
+ K_f16.alloc(ggml_nelements(K));
+ to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+ to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream);
+ K_data = (char *) K_f16.ptr;
+
+ const size_t bs = ggml_blck_size(K->type);
+ const size_t ts = ggml_type_size(K->type);
+
+ nb11 = nb11*bs*sizeof(half)/ts;
+ nb12 = nb12*bs*sizeof(half)/ts;
+ nb13 = nb13*bs*sizeof(half)/ts;
+ }
+
+ if (need_f16_V && V->type != GGML_TYPE_F16) {
+ V_f16.alloc(ggml_nelements(V));
+ to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+ to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
+ V_data = (char *) V_f16.ptr;
+
+ const size_t bs = ggml_blck_size(V->type);
+ const size_t ts = ggml_type_size(V->type);
+
+ nb21 = nb21*bs*sizeof(half)/ts;
+ nb22 = nb22*bs*sizeof(half)/ts;
+ nb23 = nb23*bs*sizeof(half)/ts;
+ }
+
+ int parallel_blocks = 1;
+
+ const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
+ const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
+
+ const dim3 block_dim(warp_size, nwarps, 1);
+ int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+ CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+
+ dim3 blocks_num;
+ if (stream_k) {
+ // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+ const int max_blocks = max_blocks_per_sm*nsm;
+ const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
+ const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
+
+ const int nblocks_stream_k = max_blocks;
+
+ const bool use_stream_k = cc >= CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
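+
+        // E.g. with hypothetical numbers ntiles_total = 100 and max_blocks = 64:
+        // tiles_nwaves = 2, so tiles_efficiency_percent = 100*100/128 = 78 and a
+        // GPU older than Ada Lovelace would fall back to one block per whole tile.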
+
+ blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
+ blocks_num.y = 1;
+ blocks_num.z = 1;
+
+ dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
+ } else {
+ GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
+ const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+
+ // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
+ parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
+
+ // parallel_blocks must not be larger than what the tensor size allows:
+ parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
+
+ // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
+ // Test whether parallel_blocks can be set to a higher value for better efficiency.
+ const int blocks_per_wave = nsm * max_blocks_per_sm;
+ int nwaves_best = 0;
+ int efficiency_percent_best = 0;
+ for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
+ const int nblocks_total = ntiles_total * parallel_blocks_test;
+ const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
+ const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
+
+ // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
+ if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
+ break;
+ }
+
+ if (efficiency_percent > efficiency_percent_best) {
+ nwaves_best = nwaves;
+ efficiency_percent_best = efficiency_percent;
+ parallel_blocks = parallel_blocks_test;
+ }
+ }
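+
+        // For instance, with hypothetical numbers blocks_per_wave = 80, ntiles_total = 30
+        // and ntiles_KQ = 8: the search starts at parallel_blocks = 2 (75% efficiency),
+        // parallel_blocks = 5 reaches 150/160 = 93% in 2 waves, and the test at
+        // parallel_blocks = 6 (3 waves) triggers the early break, so 5 is kept.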
+
+ blocks_num.x = ntiles_x;
+ blocks_num.y = parallel_blocks;
+ blocks_num.z = Q->ne[2]*Q->ne[3];
+
+ if (parallel_blocks > 1) {
+ dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+ dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+ }
+ }
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+ float logit_softcap = 0.0f;
+
+ memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+ if (logit_softcap != 0.0f) {
+ scale /= logit_softcap;
+ }
+
+ const uint32_t n_head = Q->ne[2];
+ const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ GGML_ASSERT(block_dim.x % warp_size == 0);
+ fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
+ (const char *) Q->data,
+ K_data,
+ V_data,
+ mask ? ((const char *) mask->data) : nullptr,
+ !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
+ scale, max_bias, m0, m1, n_head_log2, logit_softcap,
+ Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+ K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+ mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+ Q->nb[1], Q->nb[2], Q->nb[3],
+ nb11, nb12, nb13,
+ nb21, nb22, nb23,
+ KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+ );
+ CUDA_CHECK(cudaGetLastError());
+
+ if (stream_k) {
+ if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+ const dim3 block_dim_combine(DV, 1, 1);
+ const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
+
+ flash_attn_stream_k_fixup<DV, ncols1, ncols2>
+ <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+ ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+ }
+ } else if (parallel_blocks > 1) {
+ const dim3 block_dim_combine(DV, 1, 1);
+ const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+ const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
+
+ flash_attn_combine_results_new<DV>
+ <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
+ (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
+ }
+ CUDA_CHECK(cudaGetLastError());
+}
+
+
+template <int DKQ, int DV, int ncols1, int ncols2>
+void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+ const int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+
+ typedef fattn_mma_f16_config<DKQ, DV> c;
+
+ constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
+ constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
+ constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
+
+ const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
+
+ constexpr int ncols = ncols1 * ncols2;
+ constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp.
+ constexpr int cols_per_warp = ntiles * tile_B::I;
+ constexpr int nwarps_max_x = ncols / cols_per_warp;
+ constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
+ constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
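+    // E.g. for ncols1 = 1, ncols2 = 16 (ncols = 16): ntiles = 2, cols_per_warp = 16,
+    // nwarps_max_x = 1 and nwarps_max_y = c::nbatch_fa/16, capped at c::nwarps_max.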
+
+ static_assert(DKQ % tile_B::J == 0, "bad DKQ");
+ static_assert(DV % tile_A::J == 0, "bad DV");
+ static_assert(ncols % cols_per_warp == 0, "bad ncols");
+
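+    // Synchronous loading (nstages <= 1) reuses a single buffer alternately for K and V,
+    // while the 2-stage cp_async pipeline keeps K and V resident at the same time.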
+ const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
+
+ const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
+
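+    // With Q_in_reg the Q tile occupies shared memory only until it has been loaded
+    // into registers, so the same memory can be reused for K/V and the mask
+    // afterwards (hence the max instead of a sum).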
+ const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ?
+ std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) :
+ nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask);
+
+ float logit_softcap;
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+ fattn_kernel_t fattn_kernel;
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+ static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+ if (!shared_memory_limit_raised[id]) {
+ CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+ shared_memory_limit_raised[id] = true;
+ }
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+ } else {
+ constexpr bool use_logit_softcap = true;
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+ static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+ if (!shared_memory_limit_raised[id]) {
+ CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+ shared_memory_limit_raised[id] = true;
+ }
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+ }
+
+ launch_fattn_new_mma<DV, ncols1, ncols2>
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
+}
+
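+// Choose the smallest ncols1 (query columns per tile) that still covers Q->ne[1],
+// with the total tile width ncols1*ncols2 capped at 64 columns.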
+template <int DKQ, int DV, int ncols2>
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * Q = dst->src[0];
+
+ if constexpr (ncols2 <= 8) {
+ if (Q->ne[1] <= 8/ncols2) {
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
+ return;
+ }
+ }
+
+ if (Q->ne[1] <= 16/ncols2) {
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
+ return;
+ }
+
+ if (Q->ne[1] <= 32/ncols2) {
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
+ return;
+ }
+
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
+}
+
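+// Use the largest ncols2 in {8, 4, 2} that divides the GQA ratio, so that Q heads
+// sharing one K/V head are processed in the same tile; the optimization requires a
+// mask and no ALiBi (max_bias == 0.0f), otherwise fall back to ncols2 = 1.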
+template <int DKQ, int DV>
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * mask = dst->src[3];
+
+ float max_bias = 0.0f;
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+ const bool use_gqa_opt = mask && max_bias == 0.0f;
+
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
+
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio % 2 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
+ return;
+ }
+
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
+}
+
+void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+ const ggml_tensor * mask = dst->src[3];
+
+ GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
+
+ float max_bias = 0.0f;
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+ const bool use_gqa_opt = mask && max_bias == 0.0f;
+ GGML_ASSERT(use_gqa_opt);
+
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
+ GGML_ASSERT(gqa_ratio % 16 == 0);
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+
+ //switch (Q->ne[0]) {
+ // case 64:
+ // GGML_ASSERT(V->ne[0] == 64);
+ // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst);
+ // break;
+ // case 80:
+ // GGML_ASSERT(V->ne[0] == 80);
+ // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst);
+ // break;
+ // case 96:
+ // GGML_ASSERT(V->ne[0] == 96);
+ // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst);
+ // break;
+ // case 112:
+ // GGML_ASSERT(V->ne[0] == 112);
+ // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
+ // break;
+ // case 128:
+ // GGML_ASSERT(V->ne[0] == 128);
+ // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
+ // break;
+ // case 192:
+ // GGML_ASSERT(V->ne[0] == 128);
+ // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<192, 128>(ctx, dst);
+ // break;
+ // case 256:
+ // GGML_ASSERT(V->ne[0] == 256);
+ // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
+ // break;
+ // case 576: {
+ // // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
+ // GGML_ASSERT(V->ne[0] == 512);
+ // float max_bias = 0.0f;
+ // memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+ // const bool use_gqa_opt = mask && max_bias == 0.0f;
+ // GGML_ASSERT(use_gqa_opt);
+
+ // GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+ // const int gqa_ratio = Q->ne[2] / K->ne[2];
+ // GGML_ASSERT(gqa_ratio % 16 == 0);
+ // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ // } break;
+ // default:
+ // GGML_ABORT("fatal error");
+ // break;
+ //}
+}
+
diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cuh b/ggml/src/ggml-cuda/fattn-new-mma.cuh
new file mode 100644
index 00000000..40f867df
--- /dev/null
+++ b/ggml/src/ggml-cuda/fattn-new-mma.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index ea52fa02..725b443d 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -13,6 +13,7 @@
#include "fattn-vec-f32.cuh"
#include "fattn-wmma-f16.cuh"
#include "fattn-mma-f16.cuh"
+#include "fattn-new-mma.cuh"
#include "fattn.cuh"
#include <cstdint>
@@ -517,12 +518,28 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
return;
}
- // We need this because I haven't adapted the MMA kernels to work for different
+ //
+    // It turns out the new MMA implementation added here is slower than the
+    // previous MMA implementation.
+    // Hence, we use it only for DeepSeek with MLA enabled, where the head sizes
+    // are 576 and 512, so no other implementation works.
+ //
+ if (new_mma_available(cc) && Q->ne[0] == 576) {
+ ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
+ return;
+ }
+
+ //
+    // We need this because I haven't adapted the new MMA kernels to work for different
// K and V head sizes.
- if (K->ne[0] != V->ne[0]) {
+    // We also need it if the new MMA is not available.
+ //
+ if (!new_mma_available(cc) || K->ne[0] != V->ne[0]) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}
+    // As mentioned above, the new MMA implementation added here is slower than the previous one.
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
+ //ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
}
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 3d991b4d..f87ebb96 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -61,7 +61,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
}
}
-template <ggml_type type, int ncols_y>
+template <ggml_type type, int ncols_y, int nwarps>
static __device__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -73,10 +73,8 @@ static __device__ void mul_mat_vec_q(
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
- constexpr int nwarps = 1;
constexpr int rows_per_cuda_block = 1;
#else
- constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
@@ -139,10 +137,10 @@ static __device__ void mul_mat_vec_q(
}
}
-template <ggml_type type, int ncols_y>
+template <ggml_type type, int ncols_y, int nwarps>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
+__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data,
@@ -153,11 +151,11 @@ static __global__ void mul_mat_vec_q(
const char * cx = (const char *)vx + i02*nb02;
const char * cy = (const char *)vy + i2*nb12;
char * cdst = (char *)dst + i2*nb2;
- mul_mat_vec_q<type, ncols_y>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ mul_mat_vec_q<type, ncols_y, nwarps>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
-template <ggml_type type>
-static void mul_mat_vec_q_cuda(
+template <ggml_type type, int nwarps>
+static void mul_mat_vec_q_cuda_T(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) {
@@ -167,61 +165,61 @@ static void mul_mat_vec_q_cuda(
int id = ggml_cuda_get_device();
- int64_t nwarps = 1;
- int64_t rows_per_cuda_block = 1;
-
- if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
- switch(ncols_y) {
- case 1:
- nwarps = 4;
- rows_per_cuda_block = 1;
- break;
- case 2:
- case 3:
- case 4:
- nwarps = 4;
- rows_per_cuda_block = 2;
- break;
- case 5:
- case 6:
- case 7:
- case 8:
- nwarps = 2;
- rows_per_cuda_block = 2;
- break;
- default:
- GGML_ABORT("fatal error");
- break;
- }
- }
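+    // Must match the device-side constexpr rows_per_cuda_block in mul_mat_vec_q:
+    // 1 row per block for ncols_y == 1, 2 rows otherwise (and always 1 on RDNA2/RDNA3).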
+    int64_t rows_per_cuda_block = ggml_cuda_info().devices[id].cc < CC_RDNA2 ?
+            (ncols_y == 1 ? 1 : 2) : 1;
+
+ //if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
+ // switch(ncols_y) {
+ // case 1:
+ // nwarps = 4;
+ // rows_per_cuda_block = 1;
+ // break;
+ // case 2:
+ // case 3:
+ // case 4:
+ // nwarps = 4;
+ // rows_per_cuda_block = 2;
+ // break;
+ // case 5:
+ // case 6:
+ // case 7:
+ // case 8:
+ // nwarps = 2;
+ // rows_per_cuda_block = 2;
+ // break;
+ // default:
+ // GGML_ABORT("fatal error");
+ // break;
+ // }
+ //}
const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
const dim3 block_nums(nblocks, ne2, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
switch (ncols_y) {
case 1:
- mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
+ mul_mat_vec_q<type, 1, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 2:
- mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
+ mul_mat_vec_q<type, 2, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 3:
- mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
+ mul_mat_vec_q<type, 3, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 4:
- mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
+ mul_mat_vec_q<type, 4, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 5:
- mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
+ mul_mat_vec_q<type, 5, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 6:
- mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
+ mul_mat_vec_q<type, 6, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 7:
- mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
+ mul_mat_vec_q<type, 7, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 8:
- mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
+ mul_mat_vec_q<type, 8, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
default:
GGML_ABORT("fatal error");
@@ -229,6 +227,31 @@ static void mul_mat_vec_q_cuda(
}
}
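+// Runtime-to-compile-time dispatch of nwarps, so that __launch_bounds__ above can be
+// expressed in terms of the actual number of warps per block.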
+template <ggml_type type>
+static void mul_mat_vec_q_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) {
+ int nwarps = 1;
+ int id = ggml_cuda_get_device();
+    if (ne2 < 2 && ggml_cuda_info().devices[id].cc < CC_RDNA2) { // multiple warps only for ne2 == 1 on NVIDIA and AMD older than RDNA2
+ nwarps = ncols_y <= 4 ? 4 : 2;
+ }
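+    // E.g. ncols_y = 3 with ne2 == 1 on an NVIDIA GPU gives nwarps = 4, i.e. the
+    // launcher instantiates mul_mat_vec_q_cuda_T<type, 4> with block_dims = (WARP_SIZE, 4, 1).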
+ switch (nwarps) {
+ case 1:
+ mul_mat_vec_q_cuda_T<type, 1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst,
+ ne2, nb02, nb12, nb2, ids_nb0, stream);
+ break;
+ case 2:
+ mul_mat_vec_q_cuda_T<type, 2>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst,
+ ne2, nb02, nb12, nb2, ids_nb0, stream);
+ break;
+ default:
+ mul_mat_vec_q_cuda_T<type, 4>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst,
+ ne2, nb02, nb12, nb2, ids_nb0, stream);
+ }
+}
+
static void mul_mat_vec_q4_0_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,