diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-05-04 09:17:44 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-05-04 09:17:44 +0300 |
commit | ce2b0292e18cd8dd87776797fa455e7fc4cfeed9 (patch) | |
tree | a0c3b0eaad8e4a111faf05c7d10f7c2493f2846d | |
parent | b890e01238dd73e38c8cac99e34eed2071e1c4c1 (diff) |
CUDA: faster FA TG for GQA models (#370)
* cuda: WIP MMA FA
* Use MMA for TG also when quantized
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
23 files changed, 2158 insertions, 11 deletions
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 17a9e2eb..74ac5374 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -321,6 +321,8 @@ if (GGML_CUDA) list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu") file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "ggml-cuda/template-instances/fattn-mma*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 2eba527f..0a7f7f83 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -46,10 +46,14 @@ #define CC_VOLTA 700 #define CC_TURING 750 #define CC_AMPERE 800 +#define CC_ADA_LOVELACE 890 #define CC_OFFSET_AMD 1000000 +#define CC_OFFSET_MTHREADS 0x0100000 #define CC_RDNA1 (CC_OFFSET_AMD + 1010) #define CC_RDNA2 (CC_OFFSET_AMD + 1030) #define CC_RDNA3 (CC_OFFSET_AMD + 1100) +#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < CC_OFFSET_MTHREADS) +#define GGML_CUDA_CC_IS_AMD(cc) (cc >= CC_OFFSET_AMD) #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses @@ -134,6 +138,49 @@ typedef float2 dfloat2; #define INT8_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE +#define CP_ASYNC_AVAILABLE +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE + +#ifdef __CUDA_ARCH_LIST__ +constexpr bool ggml_cuda_has_arch_impl(int) { + return false; +} + +template<class ... Archs> +constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) { + return arch == first || ggml_cuda_has_arch_impl(arch, rest...); +} + +constexpr bool ggml_cuda_has_arch(const int arch) { + return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__); +} + +constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) { + if (cur == 0) { + GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch); + } + return cur; +} + +template<class ... Archs> +constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) { + if (first <= arch && first > cur) { + return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...); + } else { + return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...); + } +} + +constexpr int ggml_cuda_highest_compiled_arch(const int arch) { + return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__); +} +#else +static int ggml_cuda_highest_compiled_arch(const int arch) { + return arch; +} +#endif // __CUDA_ARCH_LIST__ + static constexpr bool fast_fp16_available(const int cc) { return cc >= CC_PASCAL && cc != 610; } @@ -146,6 +193,15 @@ static constexpr bool int8_mma_available(const int cc) { return cc < CC_OFFSET_AMD && cc >= CC_TURING; } +// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. +static bool new_mma_available(const int cc) { + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= CC_TURING; +} + +static bool cp_async_available(const int cc) { + return cc < CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= CC_AMPERE; +} + [[noreturn]] static __device__ void no_device_code( const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) { diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh new file mode 100644 index 00000000..ecb65999 --- /dev/null +++ b/ggml/src/ggml-cuda/cp-async.cuh @@ -0,0 +1,46 @@ +// Simplified API for asynchronous data loading. + +#include "common.cuh" + +// Copies data from global to shared memory, cg == cache global. +// Both the src and dst pointers must be aligned to 16 bit. +// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int. +// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared. +// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements. +template <int preload> +static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) { + static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload"); +#ifdef CP_ASYNC_AVAILABLE +#if CUDART_VERSION >= 11040 + if (preload == 256) { + asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else if (preload == 128) { + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else if (preload == 64) { + asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else +#endif // CUDART_VERSION >= 11040 + { + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } +#else + GGML_UNUSED(dst); + GGML_UNUSED(src); + NO_DEVICE_CODE; +#endif // CP_ASYNC_AVAILABLE +} + +// Makes each thread wait until its asynchronous data copies are done. +// This does NOT provide any additional synchronization. +// In particular, when copying data with multiple warps a call to __syncthreads will be needed. +static __device__ __forceinline__ void cp_async_wait_all() { +#ifdef CP_ASYNC_AVAILABLE + asm volatile("cp.async.wait_all;"); +#else + NO_DEVICE_CODE; +#endif // CP_ASYNC_AVAILABLE +} diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 08bb5cd3..5e8ee0f6 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -862,3 +862,336 @@ void launch_fattn( (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); CUDA_CHECK(cudaGetLastError()); } + +template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_mma_stream_k_fixup( + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { + constexpr int ncols = ncols1*ncols2; + + const int bidx0 = blockIdx.x; + const int j = blockIdx.y; + const int c = blockIdx.z; + const int jc = j*ncols2 + c; + const int tid = threadIdx.x; + + const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); + + const int iter_k = ne11 / FATTN_KQ_STRIDE; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + + const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; + const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + const int channel = kbc0 / (iter_k*iter_j); + const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; + + if (jt*ncols1 + j >= ne01) { + return; + } + + dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid; + + // Load the partial result that needs a fixup: + float dst_val = 0.0f; + float max_val = 0.0f; + float rowsum = 0.0f; + { + dst_val = *dst; + + const float2 tmp = dst_fixup[bidx0*ncols + jc]; + max_val = tmp.x; + rowsum = tmp.y; + } + + + // Iterate over previous blocks and compute the combined results. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int bidx = bidx0 - 1; + int kbc_stop = kbc0; + while(true) { + const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; + + const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc]; + + // Scale the current and new value accumulators depending on the max. values. + const float max_val_new = fmaxf(max_val, tmp.x); + + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val*rowsum + scale_add*tmp.y; + + max_val = max_val_new; + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { + break; + } + bidx--; + kbc_stop = kbc; + } + + // Write back final result: + *dst = dst_val / rowsum; +} + +template<int D> // D == head size +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_mma_combine_results( + const float * __restrict__ VKQ_parts, + const float2 * __restrict__ VKQ_meta, + float * __restrict__ dst, + const int parallel_blocks) { + VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x; + dst += D * gridDim.z*blockIdx.x; + + const int tid = threadIdx.x; + __builtin_assume(tid < D); + + extern __shared__ float2 meta[]; + if (tid < 2*parallel_blocks) { + ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid]; + } + + __syncthreads(); + + float kqmax = meta[0].x; + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = max(kqmax, meta[l].x); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; + for (int l = 0; l < parallel_blocks; ++l) { + const float diff = meta[l].x - kqmax; + float KQ_max_scale = expf(diff); + const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); + *((uint32_t *) &KQ_max_scale) &= ftz_mask; + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid]; + VKQ_denominator += KQ_max_scale * meta[l].y; + } + + dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator; +} + +template <int D, int ncols1, int ncols2, int KQ_stride> +void launch_fattn_mma( + ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, + const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE +) { + constexpr int ncols = ncols1 * ncols2; + + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + + ggml_tensor * KQV = dst; + + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + + GGML_ASSERT(Q->ne[3] == 1); + + ggml_cuda_pool & pool = ctx.pool(); + cudaStream_t main_stream = ctx.stream(); + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int nsm = ggml_cuda_info().devices[id].nsm; + + ggml_cuda_pool_alloc<half> K_f16(pool); + ggml_cuda_pool_alloc<half> V_f16(pool); + ggml_cuda_pool_alloc<float> dst_tmp(pool); + ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool); + + const char * K_data = (const char *) K->data; + size_t nb11 = K->nb[1]; + size_t nb12 = K->nb[2]; + size_t nb13 = K->nb[3]; + + const char * V_data = (const char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; + + if (need_f16_K && K->type != GGML_TYPE_F16) { + K_f16.alloc(ggml_nelements(K)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); + to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream); + K_data = (char *) K_f16.ptr; + + const size_t bs = ggml_blck_size(K->type); + const size_t ts = ggml_type_size(K->type); + + nb11 = nb11*bs*sizeof(half)/ts; + nb12 = nb12*bs*sizeof(half)/ts; + nb13 = nb13*bs*sizeof(half)/ts; + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + V_f16.alloc(ggml_nelements(V)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); + + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } + + int parallel_blocks = 1; + + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; + + const dim3 block_dim(warp_size, nwarps, 1); + dim3 blocks_num; + if (stream_k) { + // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. + const int max_blocks = 2*nsm; + const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; + const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); + + const int nblocks_stream_k = max_blocks; + + const bool use_stream_k = cc >= CC_ADA_LOVELACE || tiles_efficiency_percent < 75; + + blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; + blocks_num.y = 1; + blocks_num.z = 1; + + dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float)); + } else { + GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); + const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + + int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. + CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + + // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: + parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); + + // parallel_blocks must not be larger than what the tensor size allows: + parallel_blocks = std::min(parallel_blocks, ntiles_KQ); + + // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects. + // Test whether parallel_blocks can be set to a higher value for better efficiency. + const int blocks_per_wave = nsm * max_blocks_per_sm; + int nwaves_best = 0; + int efficiency_percent_best = 0; + for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) { + const int nblocks_total = ntiles_total * parallel_blocks_test; + const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave; + const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); + + // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead. + if (efficiency_percent_best >= 90 && nwaves > nwaves_best) { + break; + } + + if (efficiency_percent > efficiency_percent_best) { + nwaves_best = nwaves; + efficiency_percent_best = efficiency_percent; + parallel_blocks = parallel_blocks_test; + } + } + + blocks_num.x = ntiles_x; + blocks_num.y = parallel_blocks; + blocks_num.z = Q->ne[2]*Q->ne[3]; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + } + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head)))); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + GGML_ASSERT(block_dim.x % warp_size == 0); + fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>( + (const char *) Q->data, + K_data, + V_data, + mask ? ((const char *) mask->data) : nullptr, + !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, + scale, max_bias, m0, m1, n_head_log2, logit_softcap, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + nb11, nb12, nb13, + nb21, nb22, nb23, + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if (stream_k) { + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; + + flash_attn_mma_stream_k_fixup<D, ncols1, ncols2, KQ_stride> + <<<blocks_num_combine, block_dim_combine, 0, main_stream>>> + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); + } + } else if (parallel_blocks > 1) { + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z); + const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); + + flash_attn_mma_combine_results<D> + <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); + } + CUDA_CHECK(cudaGetLastError()); +} + diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh new file mode 100644 index 00000000..af38071d --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -0,0 +1,1047 @@ +#include "common.cuh" +#include "cp-async.cuh" +#include "mma_new.cuh" +#include "fattn-common.cuh" + +using namespace ggml_cuda_mma; + +typedef tile<16, 8, half2> tile_A; +typedef tile< 8, 8, half2> tile_B; +typedef tile<16, 8, half2> tile_B_16; +typedef tile<16, 8, float> tile_C_KQ; +typedef tile<16, 16, float> tile_C_KQ_16; +typedef tile<16, 4, half2> tile_C_VKQ; +typedef tile<16, 8, half2> tile_C_VKQ_16; + +template<int D, int nwarps, int KQ_per_iter> +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) { + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + + // If cp.async is available, load up to the highest power of 2 in D asynchronously: +#ifdef CP_ASYNC_AVAILABLE + static_assert(D >= 64 && D < 512, "bad D"); + constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128); + + const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); + + constexpr int preload = 64; + constexpr int h2_per_chunk = 16/sizeof(half2); + constexpr int chunks_per_row = k0_sync_start / h2_per_chunk; + constexpr int stride_i = WARP_SIZE / chunks_per_row; +#pragma unroll + for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row); + const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk; + + cp_async_cg_16<preload>(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k); + } +#else + constexpr int k0_sync_start = 0; +#endif // CP_ASYNC_AVAILABLE + static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start"); + + // If D is not a power of 2, the rest is loaded synchronously. + // K/V data is loaded with decreasing granularity for D for better memory bandwidth. + static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds"); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + + if (k0_start == k0_stop || k0_stop <= k0_sync_start) { + continue; + } + +#pragma unroll + for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_KV[i*D2_padded + k] = KV[i*stride_KV + k]; + } + } + } +} + +template<int ncols1, int nwarps, int KQ_per_iter> +static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( + const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { + static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter"); +#ifdef CP_ASYNC_AVAILABLE + constexpr int preload = KQ_per_iter * sizeof(half); + constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter; + constexpr int stride_j = nwarps * cols_per_warp; + + const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask); + +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + + (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8)); + + cp_async_cg_16<preload>(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + } +#else + constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter; + constexpr int stride_j = nwarps * cols_per_warp; +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2); + + tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i]; + } +#endif // CP_ASYNC_AVAILABLE +} + +template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter> +static __device__ __forceinline__ void flash_attn_ext_f16_iter( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half2 * const __restrict__ mask_h2, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne01, + const int ne02, + const int stride_KV, + const int stride_mask, + const int jt, + half2 * const __restrict__ tile_K, + half2 * const __restrict__ tile_V, + half2 * const __restrict__ tile_mask, + const tile_B * const __restrict__ Q_B, + tile_C_VKQ * const __restrict__ VKQ_C, + float * const __restrict__ KQ_max, + float * const __restrict__ KQ_rowsum, + const int kb0) { +#ifdef INT8_MMA_AVAILABLE + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + + const int k_VKQ_0 = kb0 * KQ_per_iter; + tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles]; + + // Use wide variants of tiles if ntiles >= 2. + tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; + tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; + +#ifdef CP_ASYNC_AVAILABLE + cp_async_wait_all(); + __syncthreads(); + flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); +#else + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV); + __syncthreads(); +#endif // CP_ASYNC_AVAILABLE + + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) { + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded); + if (ntiles == 1) { + mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); + } + } + } + } + +#ifndef CP_ASYNC_AVAILABLE + __syncthreads(); // Only needed if tile_K == tile_V. +#endif // CP_ASYNC_AVAILABLE + + if (use_logit_softcap) { + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); + } + } + } + + float KQ_max_new[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max_new[col] = KQ_max[col]; + } + float KQ_rowsum_add[cols_per_thread] = {0.0f}; + + if (ntiles == 1) { + if (ncols2 > 1 || mask_h2) { +#pragma unroll + for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + const int i = i0 + tile_C_KQ::get_i(l); + const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; + + KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * + __half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]); + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); + } + } + + // Values per KQ column are spread across 8 threads, does not need full warp reduce: +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = 16; offset >= 4; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } + + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); + + KQ_rowsum_add[l % 2] += KQ_C[k].x[l]; + } + } + } else { // ntiles > 1 + if (ncols2 > 1 || mask_h2) { +#pragma unroll + for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) { + const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; + const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; + + const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]); + const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; + KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; + KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; + } + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l = 0; l < tile_C_KQ_16::ne; ++l) { + const int KQ_index = 2*t + (l/2) % 2; + KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]); + } + } + } + + // Values per KQ column are spread across 4 threads, does not need full warp reduce: +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = 2; offset >= 1; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } + + static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l = 0; l < tile_C_KQ_16::ne; ++l) { + const int KQ_index = 2*t + (l/2) % 2; + + KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]); + + KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l]; + } + } + } + } + + { + float KQ_max_scale[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]); + KQ_max[col] = KQ_max_new[col]; + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; + } + + if (ntiles == 1) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); +#pragma unroll + for (int i = 0; i < D/tile_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { + VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + } + } + } + } + } + + // Convert KQ C tiles into B tiles for VKQ calculation: + tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles]; + tile_B_16 * B_16 = (tile_B_16 *) B; + static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size"); + if (ntiles == 1) { +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) { + B[k] = get_transposed(get_half2(KQ_C[k])); + } + } else { + for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); + } + } + } + +#ifdef CP_ASYNC_AVAILABLE + // Preload K tile for next iteration: + cp_async_wait_all(); + __syncthreads(); + if (!last_iter) { + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV); + } +#else + flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); + __syncthreads(); +#endif // CP_ASYNC_AVAILABLE + + // Calculate VKQ tile: +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) { + static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size"); +#pragma unroll + for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) { + const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + + tile_A A; + load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); + if (ntiles == 1) { + mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); + } + } + } + } + +#ifndef CP_ASYNC_AVAILABLE + __syncthreads(); // Only needed if tile_K == tile_V. +#endif // CP_ASYNC_AVAILABLE + +#else + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); + GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV); + GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); + GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); + GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); + GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum); + GGML_UNUSED(kb0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE +} + +template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup> +static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half2 * const __restrict__ mask_h2, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne01, + const int ne02, + const int stride_Q1, + const int stride_Q2, + const int stride_KV, + const int stride_mask, + const int jt, + const int kb0_start, + const int kb0_stop) { +#ifdef INT8_MMA_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + constexpr int ncols = ncols1 * ncols2; + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + + static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); + + static_assert(D % nwarps == 0, "bad D"); + static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter"); + + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + + // Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements: + extern __shared__ half2 tile_K[]; +#ifdef CP_ASYNC_AVAILABLE + half2 * tile_V = tile_K + KQ_per_iter*D2_padded; +#else + half2 * tile_V = tile_K; +#endif // CP_ASYNC_AVAILABLE + half2 * tile_mask = tile_V + KQ_per_iter*D2_padded; + + tile_B Q_B[D/(2*tile_B::J) * ntiles]; + tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles]; + + tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; + tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + + float KQ_rowsum[cols_per_thread] = {0.0f}; + float KQ_max[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max[col] = -FLT_MAX/2.0f; + } + + // Temporarily load Q data into tile_K, will be loaded into registers afterwards. + // The loading is done with decreasing granularity for D for better memory bandwidth. + const half2 scale_h2 = make_half2(scale, scale); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + +#pragma unroll + for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) { + const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) { + break; + } + + const int j = jc / ncols2; + const int c = jc % ncols2; + + if (jt*ncols1 + j < ne01) { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; + tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); + } + } else { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f); + } + } + } + } + + __syncthreads(); + + { + const int j0 = (threadIdx.y / np) * cols_per_warp; + +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + if (ntiles == 1) { + load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], + tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded); + } + } + } + } + + __syncthreads(); + + // Preload mask and K data for first iteration when using cp_async: +#ifdef CP_ASYNC_AVAILABLE + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV); +#endif // CP_ASYNC_AVAILABLE + + // Iterate over ne11 == previous tokens: + for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { + constexpr bool last_iter = false; + flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter> + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + } + { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + constexpr bool last_iter = true; + flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter> + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + } + + // With cp_async there is no __syncthreads at the end of the iter, + // there can be a race condition on shared memory access for combining/writing back results. +#ifdef CP_ASYNC_AVAILABLE + if (nwarps*cols_per_warp > KQ_per_iter) { + __syncthreads(); + } +#endif // CP_ASYNC_AVAILABLE + + // Finally, sum up partial KQ rowsums. + // The partial sums are spread across 8/4 threads each, does not need full reduce. + { + constexpr int offset_first = ntiles == 1 ? 16 : 2; + constexpr int offset_last = ntiles == 1 ? 4 : 1; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { + KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); + } + } + } + + // Write VKQ accumulators to shared memory in column-major format. + // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. + // Also for np > 1 the combination is done via these values in shared memory. + if (ntiles == 1) { + const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. + +#pragma unroll + for (int l = 0; l < tile_B::ne; ++l) { + const int k = k0 + tile_B::get_j(l); + + tile_K[jc_cwd*D2_padded + k] = B.x[l]; + } + } + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { + const int j = j0 + tile_C_VKQ_16::get_i(l); + const int k = k0 + tile_C_VKQ_16::get_j(l); + + tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; + } + } + } + } + + if constexpr (ntiles == 1) { + const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset + const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta + const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + } + + __syncthreads(); + + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && threadIdx.x < tile_B::I) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && threadIdx.x < tile_B::I) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + } + } else { + static_assert(ntiles == 2 || ntiles == 4, "bad ntiles"); + const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta + + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0) + + tile_C_VKQ_16::get_i(threadIdx.x % 4); + const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + } + + __syncthreads(); + + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + } + } + + static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles"); + if (np > 1 && threadIdx.y % np == 0) { + // Combine the meta data for parallel warps via shared memory. + // Warps with threadIdx.y % np != 0 must NOT return early. + // All threads must return simultaneously to avoid race conditions with work on the next tile. + + constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; + + const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); + float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4; + float2 meta[nmeta]; +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2]; + } + + float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. +#pragma unroll + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x); + } +#pragma unroll + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset >= WARP_SIZE) { + continue; + } + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); + } + + float KQ_cms[nmeta]; // KQ combine max scale per warp. +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn); + } + + float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps. +#pragma unroll + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_crs += KQ_cms[imeta]*meta[imeta].y; + } +#pragma unroll + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset >= WARP_SIZE) { + continue; + } + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); + } + + // Write back combined meta data: +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { + // Combined KQ max scale + rowsum. + meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs); + } + } + + // Combined KQ max + rowsum. + static_assert(cols_per_warp <= WARP_SIZE); + if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + } + + if (np > 1) { + __syncthreads(); + } + + if (np == 1 || threadIdx.y % np == 0) { + // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. + // The values after that are for the partial results of the individual blocks. + float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2)); + +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + +#pragma unroll + for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { + break; + } + + const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; + + const int j_dst = jc_dst / ncols2; + const int c_dst = jc_dst % ncols2; + + if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + continue; + } + + const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2; +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + float2 dstk_val = make_float2(0.0f, 0.0f); +#pragma unroll + for (int ip = 0; ip < np; ++ip) { + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0]; + const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]); + dstk_val.x += dstk_val_add.x*KQ_crs; + dstk_val.y += dstk_val_add.y*KQ_crs; + } + + if (!needs_fixup && !is_fixup) { + const float KQ_rowsum_j = meta_j[1]; + dstk_val.x /= KQ_rowsum_j; + dstk_val.y /= KQ_rowsum_j; + } + + if (is_fixup) { + dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val; + } else { + dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val; + } + } + } + } + } + + if (np > 1) { + __syncthreads(); + } +#else + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); + GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); + GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask); + GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE +} + +template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap> +__launch_bounds__(nwarps*WARP_SIZE, 2) +static __global__ void flash_attn_mma_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const float logit_softcap, + const uint32_t n_head_log2, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if defined(INT8_MMA_AVAILABLE) + + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + + static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter"); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + + const int stride_Q1 = nb01 / sizeof(float2); + const int stride_Q2 = nb02 / sizeof(float2); + const int stride_KV = nb11 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half2); + + const int iter_k = ne11 / FATTN_KQ_STRIDE; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + + constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice. + + // kbc == k block continuous, current index in continuous ijk space. + int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + + // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. + // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). + // In the most general case >2 seams can fall into the same tile. + + // kb0 == k start index when in the output tile. + int kb0_start = kbc % iter_k; + int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); + while (kbc < kbc_stop && kb0_stop == iter_k) { + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + + const int kb0_start_kernel = kb0_start * kb_niter; + const int kb0_stop_kernel = kb0_stop * kb_niter; + + constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. + if (kb0_start == 0) { + constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. + flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup> + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + } else { + constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. + flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup> + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + } + + kbc += iter_k; + kbc -= kbc % iter_k; + + kb0_start = 0; + kb0_stop = min(iter_k, kbc_stop - kbc); + } + + if (kbc >= kbc_stop) { + return; + } + + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + + const int kb0_start_kernel = kb0_start * kb_niter; + const int kb0_stop_kernel = kb0_stop * kb_niter; + + constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. + constexpr bool needs_fixup = false; + flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup> + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); +#else + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); + GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); + GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); + NO_DEVICE_CODE; +#endif // defined(INT8_MMA_AVAILABLE) +} + +template <int D, int ncols1, int ncols2> +void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + constexpr int ncols = ncols1 * ncols2; + constexpr int KQ_per_iter = D <= 128 && ncols1 <= 64 ? 64 : 32; + constexpr int nwarps = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4; + constexpr int ntiles = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4); + constexpr int cols_per_warp = ntiles * tile_B::I; + + static_assert(D % tile_B::J == 0, "bad D"); + static_assert(ncols % cols_per_warp == 0, "bad ncols"); + + const ggml_tensor * KQV = dst; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + + const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter; + + const size_t nbytes_shared_KV = KQ_shared_rows * (D + 8) * sizeof(half); + const size_t nbytes_shared_mask = ncols1 * (KQ_per_iter + 8) * sizeof(half); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D + 8) * sizeof(half); + + const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine); + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_mma_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_mma_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>; + } + + launch_fattn_mma<D, ncols1, ncols2, KQ_per_iter> + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); +} + + +#define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2) \ + template void ggml_cuda_flash_attn_ext_mma_f16_case \ + <D, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \ + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64) + +// Kernels with ncols == 128 are only 4% faster due to register pressure. +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory. diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 55116058..fce66c20 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -12,10 +12,94 @@ #include "fattn-vec-f16.cuh" #include "fattn-vec-f32.cuh" #include "fattn-wmma-f16.cuh" +#include "fattn-mma-f16.cuh" #include "fattn.cuh" #include <cstdint> +template <int D, int ncols2> +static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + if (Q->ne[1] <= 8/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst); + return; + } + + if (Q->ne[1] <= 16/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst); + return; + } + + if (Q->ne[1] <= 32/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst); +} + +template <int ncols2> +static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + +static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const float use_gqa_opt = mask && max_bias == 0.0f; + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + + if (use_gqa_opt && gqa_ratio % 8 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio == 4) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio == 2) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst); +} + static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; @@ -371,8 +455,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; @@ -389,7 +476,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } if (!fast_fp16_available(cc)) { - if (Q->ne[1] <= 8) { + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); } else { ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); @@ -398,23 +485,43 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } if (!fp16_mma_available(cc)) { - if (Q->ne[1] <= 8) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + if (precision == GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { + ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + } } else { - ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { + ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); + } } return; } - if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { + const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations + // So, not sure why in mainline they thought that for CC_ADA_LOVELACE or when KV cache is not f16 the vector kernels are faster. + // On my GPU (RTX-4080) MMA is efinitely faster for GQA, both for f16 and for quantized KV cache. + //const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; + //const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion; + const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies; + const bool can_use_vector_kernel = Q->ne[0] % (2*WARP_SIZE) == 0; + if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { if (precision == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - return; - } else if(Q->ne[0] <= 128) { + } else { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); - return; } + return; + } + + // We need this because I haven't adapted the MMA kernels to work for different + // K and V head sizes. + if (K->ne[0] != V->ne[0]) { + ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); } - ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); } diff --git a/ggml/src/ggml-cuda/mma_new.cuh b/ggml/src/ggml-cuda/mma_new.cuh new file mode 100644 index 00000000..5bba41c6 --- /dev/null +++ b/ggml/src/ggml-cuda/mma_new.cuh @@ -0,0 +1,396 @@ +// This file contains primitives that expose the tensor core PTX instructions for CUDA code. +// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout. +// The documentation for the PTX instructions can be found under: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction +// +// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C. +// A is a row-major matrix with shape M x K. +// B is a column-major matrix with shape K x N. +// C is a column-major matrix with shape M x N. +// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns. +// Note that J is measured in physical 32 bit elements instead of logical elements. +// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile. +// All matrix tiles have ne physical 32 bit elements per warp. +// +// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. + +#include "common.cuh" + + +#if CUDART_VERSION >= 11080 + +static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { + int ret = 0; + +#ifdef INT8_MMA_AVAILABLE + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" + : "=r"(ret) : "r"(x)); +#else + GGML_UNUSED(x); + NO_DEVICE_CODE; +#endif // defined(INT8_MMA_AVAILABLE) + return ret; +} + +#else + +static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { + // Imagine transposing row-major matrix to column-major matrix. + const int src_i_low = 2 * (threadIdx.x % 4); + const int src_i_high = src_i_low + 1; + const int src_j = threadIdx.x / 4; + + const int src_laneid_low = src_i_low * 4 + src_j / 2; + const int src_laneid_high = src_i_high * 4 + src_j / 2; + + const int shift_low = ((src_j + 0) % 2) * 16; + const int shift_high = ((src_j + 1) % 2) * 16; + + const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF; + const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000; + + return ret_low | ret_high; +} + +#endif // CUDART_VERSION >= 11080 + +static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) { + half2 ret; + *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x)); + return ret; +} + +namespace ggml_cuda_mma { + + template <int I_, int J_, typename T> + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr int ne = I * J / WARP_SIZE; + T x[ne] = {0}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && (J == 4 || J == 8)) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 8) { + return (l / 2) * 8 + threadIdx.x / 4; + } else if constexpr (I == 16 && J == 16) { + return ((l / 2) % 2) * 8 + threadIdx.x / 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 8 && J == 8) { + return 4 * l + threadIdx.x % 4; + } else if constexpr (I == 16 && J == 8) { + return 2 * (threadIdx.x % 4) + l % 2; + } else if constexpr (I == 16 && J == 16) { + return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + }; + + template <int I_, int J_> + struct tile<I_, J_, half2> { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr int ne = I * J / WARP_SIZE; + half2 x[ne] = {{0.0f, 0.0f}}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 8) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 4) { + return l * 8 + threadIdx.x / 4; + } else if constexpr (I == 16 && J == 8) { + return (l % 2) * 8 + threadIdx.x / 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 8) { + return l * 4 + threadIdx.x % 4; + } else if constexpr (I == 16 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 16 && J == 8) { + return (l / 2) * 4 + threadIdx.x % 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + }; + + template <int I, int J> + static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) { + tile<I, J/2, half2> ret; +#pragma unroll + for (int l0 = 0; l0 < tile_float.ne; l0 += 2) { + ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); + } + return ret; + } + + static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) { + tile<8, 8, half2> ret; + ret.x[0] = ggml_cuda_movmatrix(t.x[0]); + ret.x[1] = ggml_cuda_movmatrix(t.x[1]); + + return ret; + } + + template <int I, int J, typename T> + static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) { +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; + } + } + + template <typename T> + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef INT8_MMA_AVAILABLE + int * xi = (int *) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J; + asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "=r"(xi[0]), "=r"(xi[1]) + : "l"(xs)); +#else + load_generic(t, xs0, stride); +#endif // INT8_MMA_AVAILABLE + } + + template <typename T> + static __device__ __forceinline__ void load_ldmatrix( + tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef INT8_MMA_AVAILABLE + int * xi = (int *) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; + asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "=r"(xi[0]), "=r"(xi[1]) + : "l"(xs)); +#else + load_generic(xs0, stride); + GGML_UNUSED(t); +#endif // INT8_MMA_AVAILABLE + } + + template <typename T> + static __device__ __forceinline__ void load_ldmatrix( + tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef INT8_MMA_AVAILABLE + int * xi = (int * ) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) + : "l"(xs)); +#else + load_generic(t, xs0, stride); +#endif // INT8_MMA_AVAILABLE + } + + template <typename T> + static __device__ __forceinline__ void load_ldmatrix_trans( + tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef INT8_MMA_AVAILABLE + int * xi = (int * ) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) + : "l"(xs)); +#else + GGML_UNUSED(t); + GGML_UNUSED(xs0); + GGML_UNUSED(stride); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) { +#ifdef INT8_MMA_AVAILABLE +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead: + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[0]), "r"(B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[1]), "r"(B.x[0])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) { +#ifdef INT8_MMA_AVAILABLE +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1])); +#else + // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead: + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[0]), "r"(B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[1]), "r"(B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[2]), "r"(B.x[1])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[3]), "r"(B.x[1])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { +#ifdef INT8_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { +#ifdef INT8_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3])); +#else + // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { +#ifdef INT8_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { +#ifdef INT8_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3])); +#else + // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } +} diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu new file mode 100644 index 00000000..80108615 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 1, 8); +DECL_FATTN_MMA_F16_CASE(80, 1, 8); +DECL_FATTN_MMA_F16_CASE(96, 1, 8); +DECL_FATTN_MMA_F16_CASE(112, 1, 8); +DECL_FATTN_MMA_F16_CASE(128, 1, 8); +DECL_FATTN_MMA_F16_CASE(256, 1, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu new file mode 100644 index 00000000..66161c0a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 16, 1); +DECL_FATTN_MMA_F16_CASE(80, 16, 1); +DECL_FATTN_MMA_F16_CASE(96, 16, 1); +DECL_FATTN_MMA_F16_CASE(112, 16, 1); +DECL_FATTN_MMA_F16_CASE(128, 16, 1); +DECL_FATTN_MMA_F16_CASE(256, 16, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu new file mode 100644 index 00000000..ee88c72a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 16, 2); +DECL_FATTN_MMA_F16_CASE(80, 16, 2); +DECL_FATTN_MMA_F16_CASE(96, 16, 2); +DECL_FATTN_MMA_F16_CASE(112, 16, 2); +DECL_FATTN_MMA_F16_CASE(128, 16, 2); +DECL_FATTN_MMA_F16_CASE(256, 16, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu new file mode 100644 index 00000000..d888a5a4 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 16, 4); +DECL_FATTN_MMA_F16_CASE(80, 16, 4); +DECL_FATTN_MMA_F16_CASE(96, 16, 4); +DECL_FATTN_MMA_F16_CASE(112, 16, 4); +DECL_FATTN_MMA_F16_CASE(128, 16, 4); +DECL_FATTN_MMA_F16_CASE(256, 16, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu new file mode 100644 index 00000000..d93a2d08 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 2, 4); +DECL_FATTN_MMA_F16_CASE(80, 2, 4); +DECL_FATTN_MMA_F16_CASE(96, 2, 4); +DECL_FATTN_MMA_F16_CASE(112, 2, 4); +DECL_FATTN_MMA_F16_CASE(128, 2, 4); +DECL_FATTN_MMA_F16_CASE(256, 2, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu new file mode 100644 index 00000000..617464c9 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 2, 8); +DECL_FATTN_MMA_F16_CASE(80, 2, 8); +DECL_FATTN_MMA_F16_CASE(96, 2, 8); +DECL_FATTN_MMA_F16_CASE(112, 2, 8); +DECL_FATTN_MMA_F16_CASE(128, 2, 8); +DECL_FATTN_MMA_F16_CASE(256, 2, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu new file mode 100644 index 00000000..970d2b68 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 32, 1); +DECL_FATTN_MMA_F16_CASE(80, 32, 1); +DECL_FATTN_MMA_F16_CASE(96, 32, 1); +DECL_FATTN_MMA_F16_CASE(112, 32, 1); +DECL_FATTN_MMA_F16_CASE(128, 32, 1); +DECL_FATTN_MMA_F16_CASE(256, 32, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu new file mode 100644 index 00000000..65cd377c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 32, 2); +DECL_FATTN_MMA_F16_CASE(80, 32, 2); +DECL_FATTN_MMA_F16_CASE(96, 32, 2); +DECL_FATTN_MMA_F16_CASE(112, 32, 2); +DECL_FATTN_MMA_F16_CASE(128, 32, 2); +DECL_FATTN_MMA_F16_CASE(256, 32, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu new file mode 100644 index 00000000..f4a8bf34 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 4, 2); +DECL_FATTN_MMA_F16_CASE(80, 4, 2); +DECL_FATTN_MMA_F16_CASE(96, 4, 2); +DECL_FATTN_MMA_F16_CASE(112, 4, 2); +DECL_FATTN_MMA_F16_CASE(128, 4, 2); +DECL_FATTN_MMA_F16_CASE(256, 4, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu new file mode 100644 index 00000000..de191a8a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 4, 4); +DECL_FATTN_MMA_F16_CASE(80, 4, 4); +DECL_FATTN_MMA_F16_CASE(96, 4, 4); +DECL_FATTN_MMA_F16_CASE(112, 4, 4); +DECL_FATTN_MMA_F16_CASE(128, 4, 4); +DECL_FATTN_MMA_F16_CASE(256, 4, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu new file mode 100644 index 00000000..e8cb0e1b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 4, 8); +DECL_FATTN_MMA_F16_CASE(80, 4, 8); +DECL_FATTN_MMA_F16_CASE(96, 4, 8); +DECL_FATTN_MMA_F16_CASE(112, 4, 8); +DECL_FATTN_MMA_F16_CASE(128, 4, 8); +DECL_FATTN_MMA_F16_CASE(256, 4, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu new file mode 100644 index 00000000..a532e962 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 1); +DECL_FATTN_MMA_F16_CASE(80, 64, 1); +DECL_FATTN_MMA_F16_CASE(96, 64, 1); +DECL_FATTN_MMA_F16_CASE(112, 64, 1); +DECL_FATTN_MMA_F16_CASE(128, 64, 1); +DECL_FATTN_MMA_F16_CASE(256, 64, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu new file mode 100644 index 00000000..bf25181a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 1); +DECL_FATTN_MMA_F16_CASE(80, 8, 1); +DECL_FATTN_MMA_F16_CASE(96, 8, 1); +DECL_FATTN_MMA_F16_CASE(112, 8, 1); +DECL_FATTN_MMA_F16_CASE(128, 8, 1); +DECL_FATTN_MMA_F16_CASE(256, 8, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu new file mode 100644 index 00000000..378c132e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 2); +DECL_FATTN_MMA_F16_CASE(80, 8, 2); +DECL_FATTN_MMA_F16_CASE(96, 8, 2); +DECL_FATTN_MMA_F16_CASE(112, 8, 2); +DECL_FATTN_MMA_F16_CASE(128, 8, 2); +DECL_FATTN_MMA_F16_CASE(256, 8, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu new file mode 100644 index 00000000..372641be --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 4); +DECL_FATTN_MMA_F16_CASE(80, 8, 4); +DECL_FATTN_MMA_F16_CASE(96, 8, 4); +DECL_FATTN_MMA_F16_CASE(112, 8, 4); +DECL_FATTN_MMA_F16_CASE(128, 8, 4); +DECL_FATTN_MMA_F16_CASE(256, 8, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu new file mode 100644 index 00000000..9ff5968b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 8); +DECL_FATTN_MMA_F16_CASE(80, 8, 8); +DECL_FATTN_MMA_F16_CASE(96, 8, 8); +DECL_FATTN_MMA_F16_CASE(112, 8, 8); +DECL_FATTN_MMA_F16_CASE(128, 8, 8); +DECL_FATTN_MMA_F16_CASE(256, 8, 8); |