diff options
Diffstat (limited to 'ggml/src')
17 files changed, 387 insertions, 258 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 85df0694..410c6406 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3275,6 +3275,10 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons if (op->src[0]->ne[0] == 128) { return true; } + if (op->src[1]->ne[0] == 192 && op->src[2]->ne[0] == 128) { + return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) || + (op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0); + } if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) { return true; } diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 0a664dbd..a46f03e5 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -52,7 +52,7 @@ typedef half (*vec_dot_KQ_f16_t)( typedef float (*vec_dot_KQ_f32_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); -template<typename T, int D> +template<typename T, int Dk> static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -62,7 +62,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -92,7 +92,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( return sum; } -template<typename T, int D> +template<typename T, int Dk> static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -102,7 +102,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -142,7 +142,7 @@ static __device__ __forceinline__ int get_one_int_from_table_16(const int & q4) return *((const int *) &val0_8); } -template<typename T, int D> +template<typename T, int Dk> static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -152,7 +152,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -179,7 +179,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl( return sum; } -template<typename T, int D> +template<typename T, int Dk> static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -189,7 +189,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -226,7 +226,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( return sum; } -template<typename T, int D> +template<typename T, int Dk> static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -277,7 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( return sum; } -template<typename T, int D> +template<typename T, int Dk> static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -287,7 +287,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -320,7 +320,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0( return sum; } -template <typename T, int D> +template <typename T, int Dk> static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -330,7 +330,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_0; @@ -353,7 +353,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( return sum; } -template <typename T, int D> +template <typename T, int Dk> static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { @@ -368,7 +368,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( half2 sum2 = make_half2(0.0f, 0.0f); #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/2; k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const half2 K_ik = K_h2[k_KQ]; @@ -384,7 +384,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( float sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/2; k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const half2 K_ik = K_h2[k_KQ]; @@ -603,29 +603,29 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v return x[i]; } -template <int D> +template <int Dk> constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) { - return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> : - type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> : - type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<half, D> : - type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> : - type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> : - type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<half, D> : - type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> : - type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> : + return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, Dk> : + type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, Dk> : + type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<half, Dk> : + type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, Dk> : + type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, Dk> : + type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<half, Dk> : + type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, Dk> : + type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, Dk> : nullptr; } -template <int D> +template <int Dk> constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) { - return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> : - type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> : - type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<float, D> : - type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> : - type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> : - type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<float, D> : - type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> : - type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> : + return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, Dk> : + type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, Dk> : + type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<float, Dk> : + type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, Dk> : + type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, Dk> : + type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<float, Dk> : + type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, Dk> : + type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, Dk> : nullptr; } @@ -653,20 +653,20 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } -template<int D, int parallel_blocks> // D == head size +template<int Dv, int parallel_blocks> // Dv == V head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(D, 1) +__launch_bounds__(Dv, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const float2 * __restrict__ VKQ_meta, float * __restrict__ dst) { - VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; - VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; - dst += D * gridDim.y*blockIdx.x; + VKQ_parts += parallel_blocks*Dv * gridDim.y*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; + dst += Dv * gridDim.y*blockIdx.x; const int tid = threadIdx.x; - __builtin_assume(tid < D); + __builtin_assume(tid < Dv); __shared__ float2 meta[parallel_blocks]; if (tid < 2*parallel_blocks) { @@ -690,20 +690,20 @@ static __global__ void flash_attn_combine_results( const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); *((uint32_t *) &KQ_max_scale) &= ftz_mask; - VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*Dv + blockIdx.y*Dv + tid]; VKQ_denominator += KQ_max_scale * meta[l].y; } - dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; + dst[blockIdx.y*Dv + tid] = VKQ_numerator / VKQ_denominator; } -static void on_no_fattn_vec_case(const int D) { - if (D == 64) { +static void on_no_fattn_vec_case(const int Dk, const int Dv) { + if (Dk == 64 && Dv == 64) { fprintf(stderr, "Unsupported KV type combination for head_size 64.\n"); fprintf(stderr, "By default only f16 KV cache is supported.\n"); fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n"); GGML_ABORT("fatal error"); - } else if (D == 128) { + } else if (Dk == 128 && Dv == 128) { fprintf(stderr, "Unsupported KV type combination for head_size 128.\n"); fprintf(stderr, "Supported combinations:\n"); fprintf(stderr, " - K == q4_0, V == q4_0, 4.5 BPV\n"); @@ -715,14 +715,22 @@ static void on_no_fattn_vec_case(const int D) { fprintf(stderr, " - K == f16, V == f16, 16.0 BPV\n"); fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n"); GGML_ABORT("fatal error"); + } + else if (Dk == 192 && Dv == 128) { + fprintf(stderr, "Unsupported KV type combination for head_sizes 192 / 128\n"); + // TODO: add what is supported + } + else if (Dk == 576 && Dv == 512) { + fprintf(stderr, "Unsupported KV type combination for head_sizes 576 / 512\n"); + // TODO: add what is supported } else { - fprintf(stderr, "Unsupported KV type combination for head_size 256.\n"); + fprintf(stderr, "Unsupported KV type combination for head_sizes %d, %d.\n", Dk, Dv); fprintf(stderr, "Only f16 is supported.\n"); GGML_ABORT("fatal error"); } } -template <int D, int parallel_blocks> +template <int Dk, int Dv, int parallel_blocks> void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V @@ -838,11 +846,11 @@ void launch_fattn( return; } - const dim3 block_dim_combine(D, 1, 1); + const dim3 block_dim_combine(Dv, 1, 1); const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); const int shmem_combine = 0; - flash_attn_combine_results<D, parallel_blocks> + flash_attn_combine_results<Dv, parallel_blocks> <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>> (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); CUDA_CHECK(cudaGetLastError()); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index d1bbf01f..bf2a4521 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -291,13 +291,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int D = 64; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_softcap>; - launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_softcap>; - launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 25908d7a..28846561 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -4,7 +4,7 @@ #define FATTN_KQ_STRIDE_TILE_F32 32 -template<int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size +template<int Dk, int Dv, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -44,8 +44,9 @@ static __global__ void flash_attn_tile_ext_f32( const int ne1, const int ne2, const int ne3) { + static_assert(Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)); // Skip unused kernel variants for faster compilation: - if (use_softcap && !(D == 128 || D == 256)) { + if (use_softcap && !(Dk == 128 || Dk == 256)) { NO_DEVICE_CODE; return; } @@ -61,15 +62,22 @@ static __global__ void flash_attn_tile_ext_f32( const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; - const int stride_KV2 = nb11 / sizeof(half2); + const int stride_K2 = nb11 / sizeof(half2); + const int stride_V2 = nb12 / sizeof(half2); const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); + // TODO: is it Dk or Dv or both that need to be multiple of 2*WARP_SIZE ? + // let's assume it is is both. + static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64."); + static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64."); + + constexpr int Dkv = Dk < Dv ? Dv : Dk; // let's use this when we don't understand if it is Dk or Dv __shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32]; - __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts. + // This is being used to store either K or V data. Hence we need max(Dk, Dv) as the dimension + __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][Dkv + 1]; // Pad D to avoid memory bank conflicts. float2 * KV_tmp2 = (float2 *) KV_tmp; float kqmax[ncols/nwarps]; @@ -79,16 +87,16 @@ static __global__ void flash_attn_tile_ext_f32( } float kqsum[ncols/nwarps] = {0.0f}; - float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; + float2 VKQ[ncols/nwarps][(Dv/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; // Convert Q to half2 and store in registers: - __shared__ float Q_f[ncols][D]; + __shared__ float Q_f[ncols][Dk]; #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) { + for (int i0 = 0; i0 < Dk; i0 += 2*WARP_SIZE) { float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f); Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale; Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale; @@ -112,8 +120,8 @@ static __global__ void flash_attn_tile_ext_f32( const int i_KQ = i_KQ_0 + threadIdx.y; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) { - const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x]; + for (int k_KQ_0 = 0; k_KQ_0 < Dk; k_KQ_0 += 2*WARP_SIZE) { + const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_K2 + k_KQ_0/2 + threadIdx.x]; KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp); KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp); } @@ -124,7 +132,7 @@ static __global__ void flash_attn_tile_ext_f32( float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}}; #pragma unroll - for (int k_KQ = 0; k_KQ < D; ++k_KQ) { + for (int k_KQ = 0; k_KQ < Dk; ++k_KQ) { float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE]; float Q_k[ncols/nwarps]; @@ -193,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f32( kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale; VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale; } @@ -206,11 +214,11 @@ static __global__ void flash_attn_tile_ext_f32( const int k = k0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]); - KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]); + KV_tmp2[k*(Dv/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]); + KV_tmp2[k*(Dv/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]); } } @@ -218,14 +226,14 @@ static __global__ void flash_attn_tile_ext_f32( #pragma unroll for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) { - float2 V_k[(D/2)/WARP_SIZE]; + float2 V_k[(Dv/2)/WARP_SIZE]; float KQ_k[ncols/nwarps]; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i]; + V_k[i0/WARP_SIZE] = KV_tmp2[k*(Dv/2) + i]; } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { @@ -235,7 +243,7 @@ static __global__ void flash_attn_tile_ext_f32( } #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps]; @@ -259,7 +267,7 @@ static __global__ void flash_attn_tile_ext_f32( kqsum_j = warp_reduce_sum(kqsum_j); #pragma unroll - for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) { + for (int i00 = 0; i00 < Dv; i00 += 2*WARP_SIZE) { const int i0 = i00 + 2*threadIdx.x; float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)]; @@ -268,8 +276,8 @@ static __global__ void flash_attn_tile_ext_f32( dst_val.y /= kqsum_j; } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x; - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y; + dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 0] = dst_val.x; + dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 1] = dst_val.y; } if (parallel_blocks != 1 && threadIdx.x == 0) { @@ -285,14 +293,14 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>; - launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, D, cols_per_block, nwarps, parallel_blocks, use_softcap>; + launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>; - launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, D, cols_per_block, nwarps, parallel_blocks, use_softcap>; + launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 7f14e78b..b6ba07e4 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -1,9 +1,9 @@ #include "common.cuh" #include "fattn-common.cuh" -template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size +template<int Dk, int Dv, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(D, 1) +__launch_bounds__(Dk, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ Q, @@ -43,14 +43,15 @@ static __global__ void flash_attn_vec_ext_f16( const int ne3) { #ifdef FP16_AVAILABLE // Skip unused kernel variants for faster compilation: - if (use_softcap && !(D == 128 || D == 256)) { + if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) { + if (use_softcap && !(Dk == 128 || Dk == 256)) { NO_DEVICE_CODE; return; } //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K); + constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<Dk>(type_K); constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V); @@ -67,12 +68,13 @@ static __global__ void flash_attn_vec_ext_f16( const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - constexpr int nwarps = D / WARP_SIZE; + static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64."); + static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = Dk / WARP_SIZE; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < D); + __builtin_assume(tid < Dk); - __shared__ half KQ[ncols*D]; + __shared__ half KQ[ncols*Dk]; half2 * KQ2 = (half2 *) KQ; half kqmax[ncols]; @@ -94,9 +96,9 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers: - half2 Q_h2[ncols][D/(2*WARP_SIZE)]; - int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)]; - half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; + half2 Q_h2[ncols][Dk/(2*WARP_SIZE)]; + int Q_i32[ncols][Dk/(sizeof(int)*QK8_1) == 0 ? 1 : Dk/(sizeof(int)*QK8_1)]; + half2 Q_ds[ncols][Dk/QK8_1 == 0 ? 1 : Dk/QK8_1]; if (Q_q8_1) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { @@ -107,18 +109,18 @@ static __global__ void flash_attn_vec_ext_f16( } // Reuse KQ as temporary storage for converting Q to q8_1: - int * tmp_q_i32 = (int *) &KQ[j*D]; - half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + int * tmp_q_i32 = (int *) &KQ[j*Dk]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + Dk/sizeof(int)); // Set memory to zero if out of bounds: if (ncols > 2 && ic0 + j >= ne01) { #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; tmp_q_i32[i] = 0; } - if (threadIdx.x < D/QK8_1) { + if (threadIdx.x < Dk/QK8_1) { tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f); } continue; @@ -126,7 +128,7 @@ static __global__ void flash_attn_vec_ext_f16( const float * Q_f = (const float *) (Q + j*nb01); #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { quantize_q8_1_to_shared<half2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); } } @@ -135,11 +137,11 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { - int * tmp_q_i32 = (int *) &KQ[j*D]; - half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + int * tmp_q_i32 = (int *) &KQ[j*Dk]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + Dk/sizeof(int)); #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; @@ -154,7 +156,7 @@ static __global__ void flash_attn_vec_ext_f16( const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); @@ -166,13 +168,13 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { - KQ[j*D + tid] = -HALF_MAX_HALF; + KQ[j*Dk + tid] = -HALF_MAX_HALF; } half2 VKQ[ncols] = {{0.0f, 0.0f}}; - const int k_start = parallel_blocks == 1 ? 0 : ip*D; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + const int k_start = parallel_blocks == 1 ? 0 : ip*Dk; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) { // Calculate KQ tile and keep track of new maximum KQ values: // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, @@ -186,10 +188,10 @@ static __global__ void flash_attn_vec_ext_f16( } #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += nwarps) { const int i_KQ = i_KQ_0 + threadIdx.y; - if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + if ((i_KQ_0 + nwarps > Dk && i_KQ >= Dk) || (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + i_KQ >= ne11)) { break; } @@ -209,7 +211,7 @@ static __global__ void flash_attn_vec_ext_f16( } if (threadIdx.x == 0) { - KQ[j*D + i_KQ] = sum; + KQ[j*Dk + i_KQ] = sum; } } } @@ -234,9 +236,9 @@ static __global__ void flash_attn_vec_ext_f16( const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); kqmax[j] = kqmax_new_j; - const half val = hexp(KQ[j*D + tid] - kqmax[j]); + const half val = hexp(KQ[j*Dk + tid] - kqmax[j]); kqsum[j] = kqsum[j]*KQ_max_scale + val; - KQ[j*D + tid] = val; + KQ[j*Dk + tid] = val; VKQ[j] *= __half2half2(KQ_max_scale); } @@ -244,8 +246,8 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); #pragma unroll - for (int k0 = 0; k0 < D; k0 += 2) { - if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { + for (int k0 = 0; k0 < Dv; k0 += 2) { + if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k0 >= ne11) { break; } @@ -254,7 +256,7 @@ static __global__ void flash_attn_vec_ext_f16( reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { - VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; + VKQ[j] += V_k*KQ2[j*(Dk/2) + k0/2]; } } @@ -285,27 +287,28 @@ static __global__ void flash_attn_vec_ext_f16( dst_val /= kqsum[j_VKQ]; } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; + dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val; } if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); } + } #else NO_DEVICE_CODE; #endif // FP16_AVAILABLE } -template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> +template <int Dk, int Dv, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>; - constexpr bool need_f16_K = D != 128; - constexpr bool need_f16_V = D != 128 && D != 64; - launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + constexpr int nwarps = Dk/WARP_SIZE; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>; + constexpr bool need_f16_K = Dk != 128 && Dk != 192; + constexpr bool need_f16_V = Dv != 128 && Dv != 64; + launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); } -template <int D, ggml_type type_K, ggml_type type_V> +template <int Dk, int Dv, ggml_type type_K, ggml_type type_V> void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; @@ -325,9 +328,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 1; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); } return; } @@ -336,9 +339,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 2; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); } return; } @@ -347,9 +350,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 4; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); } return; } @@ -358,9 +361,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 8; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); } return; } @@ -368,15 +371,19 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); } } #define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \ template void ggml_cuda_flash_attn_ext_vec_f16_case \ - <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + <D, D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_VEC_F16_CASE_DKDV(Dk, Dv, type_K, type_V) \ + template void ggml_cuda_flash_attn_ext_vec_f16_case \ + <Dk, Dv, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); @@ -435,3 +442,6 @@ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); + +extern DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 1aa88272..404afe2e 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -1,9 +1,9 @@ #include "common.cuh" #include "fattn-common.cuh" -template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size +template<int Dk, int Dv, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(D, 1) +__launch_bounds__(Dk, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_vec_ext_f32( const char * __restrict__ Q, @@ -42,14 +42,15 @@ static __global__ void flash_attn_vec_ext_f32( const int ne2, const int ne3) { // Skip unused kernel variants for faster compilation: - if (use_softcap && !(D == 128 || D == 256)) { + if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) { + if (use_softcap && !(Dk == 128 || Dk == 256)) { NO_DEVICE_CODE; return; } //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K); + constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<Dk>(type_K); constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V); @@ -64,15 +65,16 @@ static __global__ void flash_attn_vec_ext_f32( const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - constexpr int nwarps = D / WARP_SIZE; + static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64."); + static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = Dk / WARP_SIZE; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < D); + __builtin_assume(tid < Dk); - __shared__ float KQ[ncols*D]; + __shared__ float KQ[ncols*Dk]; #pragma unroll for (int j = 0; j < ncols; ++j) { - KQ[j*D + tid] = -FLT_MAX/2.0f; + KQ[j*Dk + tid] = -FLT_MAX/2.0f; } float kqmax[ncols]; @@ -94,9 +96,9 @@ static __global__ void flash_attn_vec_ext_f32( __syncthreads(); // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: - float2 Q_f2[ncols][D/(2*WARP_SIZE)]; - int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D >= D/(sizeof(int)*QK8_1)]; - float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; + float2 Q_f2[ncols][Dk/(2*WARP_SIZE)]; + int Q_i32[ncols][Dk/(sizeof(int)*QK8_1) == 0 ? 1 : Dk >= Dk/(sizeof(int)*QK8_1)]; + float2 Q_ds[ncols][Dk/QK8_1 == 0 ? 1 : Dk/QK8_1]; if (Q_q8_1) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { @@ -107,18 +109,18 @@ static __global__ void flash_attn_vec_ext_f32( } // Reuse KQ as temporary storage for converting Q to q8_1: - int * tmp_q_i32 = (int *) &KQ[j*D]; - float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + int * tmp_q_i32 = (int *) &KQ[j*Dk]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + Dk/sizeof(int)); // Set memory to zero if out of bounds: if (ncols > 2 && ic0 + j >= ne01) { #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; tmp_q_i32[i] = 0; } - if (threadIdx.x < D/QK8_1) { + if (threadIdx.x < Dk/QK8_1) { tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f); } continue; @@ -126,7 +128,7 @@ static __global__ void flash_attn_vec_ext_f32( const float * Q_f = (const float *) (Q + j*nb01); #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); } } @@ -135,11 +137,11 @@ static __global__ void flash_attn_vec_ext_f32( #pragma unroll for (int j = 0; j < ncols; ++j) { - int * tmp_q_i32 = (int *) &KQ[j*D]; - float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + int * tmp_q_i32 = (int *) &KQ[j*Dk]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + Dk/sizeof(int)); #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; @@ -153,7 +155,7 @@ static __global__ void flash_attn_vec_ext_f32( for (int j = 0; j < ncols; ++j) { const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); @@ -165,8 +167,8 @@ static __global__ void flash_attn_vec_ext_f32( float VKQ[ncols] = {0.0f}; - const int k_start = parallel_blocks == 1 ? 0 : ip*D; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + const int k_start = parallel_blocks == 1 ? 0 : ip*Dk; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) { // Calculate KQ tile and keep track of new maximum KQ values: float kqmax_new_arr[ncols]; @@ -176,10 +178,10 @@ static __global__ void flash_attn_vec_ext_f32( } #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += nwarps) { const int i_KQ = i_KQ_0 + threadIdx.y; - if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + if ((i_KQ_0 + nwarps > Dk && i_KQ >= Dk) || (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + i_KQ >= ne11)) { break; } @@ -195,7 +197,7 @@ static __global__ void flash_attn_vec_ext_f32( kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum); if (threadIdx.x == 0) { - KQ[j*D + i_KQ] = sum; + KQ[j*Dk + i_KQ] = sum; } } } @@ -220,9 +222,9 @@ static __global__ void flash_attn_vec_ext_f32( const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j); kqmax[j] = kqmax_new_j; - const float val = expf(KQ[j*D + tid] - kqmax[j]); + const float val = expf(KQ[j*Dk + tid] - kqmax[j]); kqsum[j] = kqsum[j]*KQ_max_scale + val; - KQ[j*D + tid] = val; + KQ[j*Dk + tid] = val; VKQ[j] *= KQ_max_scale; } @@ -230,15 +232,15 @@ static __global__ void flash_attn_vec_ext_f32( __syncthreads(); #pragma unroll - for (int k = 0; k < D; ++k) { - if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) { + for (int k = 0; k < Dv; ++k) { + if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k >= ne11) { break; } const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { - VKQ[j] += V_ki*KQ[j*D + k]; + VKQ[j] += V_ki*KQ[j*Dk + k]; } } @@ -269,24 +271,25 @@ static __global__ void flash_attn_vec_ext_f32( dst_val /= kqsum[j_VKQ]; } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; + dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val; } if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); } + } } -template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> +template <int Dk, int Dv, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>; - constexpr bool need_f16_K = D != 128; - constexpr bool need_f16_V = D != 128 && D != 64; - launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + constexpr int nwarps = Dk/WARP_SIZE; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>; + constexpr bool need_f16_K = Dk != 128 && Dk != 192; + constexpr bool need_f16_V = Dv != 128 && Dv != 64; + launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); } -template <int D, ggml_type type_K, ggml_type type_V> +template <int Dk, int Dv, ggml_type type_K, ggml_type type_V> void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; @@ -303,9 +306,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 1; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); } return; } @@ -314,9 +317,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 2; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); } return; } @@ -325,9 +328,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 4; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); } return; } @@ -336,9 +339,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 8; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); } return; } @@ -346,15 +349,19 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst); } } #define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \ template void ggml_cuda_flash_attn_ext_vec_f32_case \ - <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + <D, D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_VEC_F32_CASE_DKDV(Dk, Dv, type_K, type_V) \ + template void ggml_cuda_flash_attn_ext_vec_f32_case \ + <Dk, Dv, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); @@ -406,3 +413,6 @@ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); + +extern DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index efe78a2f..c5ffc7d1 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -5,8 +5,8 @@ #include <mma.h> #endif // FP16_MMA_AVAILABLE -// D == head size, VKQ_stride == num VKQ rows calculated in parallel: -template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_softcap> +// Dk == K head size, Dv = V head size, VKQ_stride == num VKQ rows calculated in parallel: +template<int Dk, int Dv, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_softcap> #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -47,8 +47,9 @@ static __global__ void flash_attn_ext_f16( const int ne2, const int ne3) { #ifdef FP16_MMA_AVAILABLE + static_assert(Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)); // Skip unused kernel variants for faster compilation: - if (use_softcap && !(D == 128 || D == 256)) { + if (use_softcap && !(Dk == 128 || Dk == 256)) { NO_DEVICE_CODE; return; } @@ -58,11 +59,11 @@ static __global__ void flash_attn_ext_f16( const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. - static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); + static_assert(Dk <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; - static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); + static_assert(Dk % frag_m == 0, "If ncols == 8 then Dk % frag_m must be 0."); typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K; typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V; typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b; @@ -74,30 +75,32 @@ static __global__ void flash_attn_ext_f16( static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: - constexpr int D_padded = D + 8; + constexpr int Dk_padded = Dk + 8; + constexpr int Dv_padded = Dv + 8; constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); - const int stride_Q = nb01 / sizeof(float); - const int stride_KV = nb11 / sizeof(half); + const int stride_Q = nb01 / sizeof(float); + const int stride_K = nb11 / sizeof(half); + const int stride_V = nb21 / sizeof(half); const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); const half2 slope2 = make_half2(slopef, slopef); const half2 softcap_2 = make_half2(softcap, softcap); - frag_b Q_b[D/16][ncols/frag_n]; + frag_b Q_b[Dk/16][ncols/frag_n]; // A single buffer for temporarily holding tiles of KQ and VKQ parts: constexpr int mem_KQ = ncols*kqs_padded*kqar; - constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; + constexpr int mem_VKQ_parts = VKQ_ratio*ncols*Dv_padded; __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; float * KQ_f = (float *) KQ; half2 * KQ2 = (half2 *) KQ; @@ -120,18 +123,18 @@ static __global__ void flash_attn_ext_f16( KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); } - __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. + __shared__ half VKQ[ncols*Dv_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { + if (i0 + WARP_SIZE > Dv/2 && i >= Dv/2) { break; } - VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); + VKQ2[j*(Dv_padded/2) + i] = make_half2(0.0f, 0.0f); } } @@ -140,12 +143,12 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { + if (i0 + WARP_SIZE > Dk && i >= Dk) { break; } - KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + KQ[j*Dk_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; } } @@ -153,10 +156,10 @@ static __global__ void flash_attn_ext_f16( // Load Q into tensor core fragments/registers since it will be used frequently: #pragma unroll - for (int i0 = 0; i0 < D; i0 += 16) { + for (int i0 = 0; i0 < Dk; i0 += 16) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*Dk_padded + i0, Dk_padded); } } @@ -173,9 +176,9 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); } #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk; k_KQ_0 += 16) { frag_a_K K_a; - nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_K + k_KQ_0, stride_K); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); @@ -309,9 +312,9 @@ static __global__ void flash_attn_ext_f16( } } - frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; + frag_c_VKQ VKQ_c[Dv/VKQ_stride][ncols/frag_n]; #pragma unroll - for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { + for (int i_VKQ_0 = 0; i_VKQ_0 < Dv; i_VKQ_0 += VKQ_stride) { #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); @@ -322,7 +325,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_V + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_V); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); @@ -332,15 +335,15 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); - const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); + const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*Dk_padded); #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { + for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += VKQ_stride) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { nvcuda::wmma::store_matrix_sync( - KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + KQ + offset_k + j0*Dk_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], - D_padded, nvcuda::wmma::mem_col_major); + Dk_padded, nvcuda::wmma::mem_col_major); } } @@ -358,18 +361,18 @@ static __global__ void flash_attn_ext_f16( } #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { + if (i0 + WARP_SIZE > Dv/2 && i >= Dv/2) { break; } half2 VKQ_add = make_half2(0.0f, 0.0f); #pragma unroll for (int l = 0; l < VKQ_ratio; ++l) { - VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; + VKQ_add += KQ2[l*(ncols*Dk_padded/2) + j*(Dk_padded/2) + i]; } - VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; + VKQ2[j*(Dv_padded/2) + i] = VKQ_scale*VKQ2[j*(Dv_padded/2) + i] + VKQ_add; } } @@ -392,16 +395,16 @@ static __global__ void flash_attn_ext_f16( } #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { + if (i0 + WARP_SIZE > Dv && i >= Dv) { break; } - float dst_val = VKQ[j_VKQ*D_padded + i]; + float dst_val = VKQ[j_VKQ*Dv_padded + i]; if (parallel_blocks == 1) { dst_val /= KQ_rowsum_j; } - dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; + dst[j_dst*gridDim.y*Dv + blockIdx.y*Dv + i] = dst_val; } if (parallel_blocks == 1 || threadIdx.x != 0) { @@ -446,13 +449,13 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); -template <int D, int cols_per_block, typename KQ_acc_t> +template <int Dk, int Dv, int cols_per_block, typename KQ_acc_t> void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; constexpr int nwarps = 4; - constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16; + constexpr int frag_m = cols_per_block == 8 && Dk % 32 == 0 ? 32 : 16; const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; @@ -462,29 +465,33 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm if (4*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 4; fattn_kernel_t fattn_kernel = softcap == 0.0f ? - flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> : - flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>; - launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> : + flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>; + launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } if (2*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 2; fattn_kernel_t fattn_kernel = softcap == 0.0f ? - flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> : - flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>; - launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> : + flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>; + launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } constexpr int parallel_blocks = 1; fattn_kernel_t fattn_kernel = softcap == 0.0f ? - flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> : - flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>; - launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> : + flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>; + launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } #define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \ template void ggml_cuda_flash_attn_ext_wmma_f16_case \ - <D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + <D, D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_WMMA_F16_CASE_DKDV(Dk, Dv, cols_per_block, KQ_acc_t) \ + template void ggml_cuda_flash_attn_ext_wmma_f16_case \ + <Dk, Dv, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float); extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float); @@ -518,3 +525,7 @@ extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half); extern DECL_FATTN_WMMA_F16_CASE(112, 32, half); extern DECL_FATTN_WMMA_F16_CASE(128, 32, half); extern DECL_FATTN_WMMA_F16_CASE(256, 16, half); + +extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 8, half); +extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, half); +extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, half); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index c15d6c81..6329908d 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -12,6 +12,14 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * V = dst->src[2]; + + if (Q->ne[0] != V->ne[0]) { + if (!((Q->ne[0] == 192 && V->ne[0] == 128) || (Q->ne[0] == 576 && V->ne[0] == 512))) { + fprintf(stderr, "======================= %s: Unhandled head size combination %d, %d\n", __func__, (int)Q->ne[0], (int)V->ne[0]); + GGML_ABORT("fatal error"); + } + } const int32_t precision = KQV->op_params[3]; @@ -20,22 +28,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g constexpr int cols_per_block = 16; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, float>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst); break; default: fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); @@ -46,19 +57,22 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g constexpr int cols_per_block = 32; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst); break; // case 256: // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); @@ -76,16 +90,19 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g constexpr int cols_per_block = 8; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst); break; default: fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); @@ -99,22 +116,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g constexpr int cols_per_block = 16; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst); break; default: fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); @@ -127,22 +147,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g constexpr int cols_per_block = 32; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst); break; default: fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); @@ -152,7 +175,13 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g } #define FATTN_VEC_F16_CASE(D, type_K, type_V) \ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ - ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \ + ggml_cuda_flash_attn_ext_vec_f16_case<D, D, type_K, type_V>(ctx, dst); \ + return; \ + } \ + +#define FATTN_VEC_F16_CASE_DKDV(Dk, Dv, type_K, type_V) \ + if (Q->ne[0] == (Dk) && V->ne[0] == Dv && K->type == (type_K) && V->type == (type_V)) { \ + ggml_cuda_flash_attn_ext_vec_f16_case<Dk, Dv, type_K, type_V>(ctx, dst); \ return; \ } \ @@ -218,6 +247,9 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q6_0) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0) + + FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) #else FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -231,14 +263,24 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0) + + FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + #endif // GGML_CUDA_FA_ALL_QUANTS - on_no_fattn_vec_case(Q->ne[0]); + on_no_fattn_vec_case(Q->ne[0], V->ne[0]); } #define FATTN_VEC_F32_CASE(D, type_K, type_V) \ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ - ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \ + ggml_cuda_flash_attn_ext_vec_f32_case<D, D, type_K, type_V>(ctx, dst); \ + return; \ + } \ + +#define FATTN_VEC_F32_CASE_DKDV(Dk, Dv, type_K, type_V) \ + if (Q->ne[0] == (Dk) && V->ne[0] == Dv && K->type == (type_K) && V->type == (type_V)) { \ + ggml_cuda_flash_attn_ext_vec_f32_case<Dk, Dv, type_K, type_V>(ctx, dst); \ return; \ } \ @@ -298,6 +340,9 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + + FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) #else FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -306,9 +351,12 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + + FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) #endif // GGML_CUDA_FA_ALL_QUANTS - on_no_fattn_vec_case(Q->ne[0]); + on_no_fattn_vec_case(Q->ne[0], V->ne[0]); } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-f16-f16.cu new file mode 100644 index 00000000..7dda0133 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-f16-f16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-q8_0-q8_0.cu new file mode 100644 index 00000000..740ac37d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-q8_0-q8_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-f16-f16.cu new file mode 100644 index 00000000..1ea24302 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-f16-f16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-q8_0-q8_0.cu new file mode 100644 index 00000000..6be4d042 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-q8_0-q8_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu index 2d94e65c..334e1deb 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu @@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 16, float); DECL_FATTN_WMMA_F16_CASE(112, 16, float); DECL_FATTN_WMMA_F16_CASE(128, 16, float); DECL_FATTN_WMMA_F16_CASE(256, 16, float); + +DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, float); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu index c3d9df3c..1faf3c9b 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu @@ -7,3 +7,5 @@ DECL_FATTN_WMMA_F16_CASE(80, 32, float); DECL_FATTN_WMMA_F16_CASE(96, 32, float); DECL_FATTN_WMMA_F16_CASE(112, 32, float); DECL_FATTN_WMMA_F16_CASE(128, 32, float); + +DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, float); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu index bb680e40..48973618 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu @@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 16, half); DECL_FATTN_WMMA_F16_CASE(112, 16, half); DECL_FATTN_WMMA_F16_CASE(128, 16, half); DECL_FATTN_WMMA_F16_CASE(256, 16, half); + +DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, half); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu index 073f71b1..ed92963e 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu @@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 32, half); DECL_FATTN_WMMA_F16_CASE(112, 32, half); DECL_FATTN_WMMA_F16_CASE(128, 32, half); DECL_FATTN_WMMA_F16_CASE(256, 32, half); + +DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, half); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu index d30710c5..4e221003 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu @@ -6,3 +6,5 @@ DECL_FATTN_WMMA_F16_CASE(64, 8, half); DECL_FATTN_WMMA_F16_CASE(96, 8, half); DECL_FATTN_WMMA_F16_CASE(128, 8, half); DECL_FATTN_WMMA_F16_CASE(256, 8, half); + +DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 8, half); |