diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2024-05-11 10:32:41 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-11 10:32:41 +0300 |
commit | 9cb317f77e53067f7a138cc89ef7657148eae8e6 (patch) | |
tree | 3ba1d2d80d1d7c8b4ab01f6396a3febaae26e91b /ggml-cuda | |
parent | e849648888a11de13aaaa4cb2eda3f5a9c7b444d (diff) |
ggml : full ALiBi support (#7192)
* ggml : full ALiBi support
* ggml : update ggml_soft_max_ext() CUDA, SYCL
* ggml : ggml_flash_attn_ext() support ALiBi (CPU)
* ggml : ggml_flash_attn_ext() support ALiBi (Metal)
* ggml : fix warning
* ggml : ggml_flash_attn_ext() support ALiBi (CUDA)
ggml-ci
* ggml : fix assert message
* vulkan : add dev notes
* ggml : require mask when using ALiBi
ggml-ci
* convert : fix convert for refact models
Diffstat (limited to 'ggml-cuda')
-rw-r--r-- | ggml-cuda/alibi.cu | 63 | ||||
-rw-r--r-- | ggml-cuda/alibi.cuh | 5 | ||||
-rw-r--r-- | ggml-cuda/fattn.cu | 72 | ||||
-rw-r--r-- | ggml-cuda/softmax.cu | 55 |
4 files changed, 83 insertions, 112 deletions
diff --git a/ggml-cuda/alibi.cu b/ggml-cuda/alibi.cu deleted file mode 100644 index 6c7f1fd9..00000000 --- a/ggml-cuda/alibi.cu +++ /dev/null @@ -1,63 +0,0 @@ -#include "alibi.cuh" - -static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows, - const int n_heads_log2_floor, const float m0, const float m1) { - const int col = blockDim.x*blockIdx.x + threadIdx.x; - - if (col >= ncols) { - return; - } - - const int row = blockDim.y*blockIdx.y + threadIdx.y; - const int i = row*ncols + col; - - const int k = row/k_rows; - - float m_k; - if (k < n_heads_log2_floor) { - m_k = powf(m0, k + 1); - } else { - m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - dst[i] = col * m_k + x[i]; -} - -static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, - const int k_rows, const int n_heads_log2_floor, const float m0, - const float m1, cudaStream_t stream) { - const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1); - const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE); - const dim3 block_nums(num_blocks_x, nrows, 1); - alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1); -} - -void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t nrows = ggml_nrows(src0); - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - //GGML_ASSERT(ne01 + n_past == ne00); - GGML_ASSERT(n_head == ne02); - - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - alibi_f32_cuda(src0_d, dst_d, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, stream); -} diff --git a/ggml-cuda/alibi.cuh b/ggml-cuda/alibi.cuh deleted file mode 100644 index 630adfc7..00000000 --- a/ggml-cuda/alibi.cuh +++ /dev/null @@ -1,5 +0,0 @@ -#include "common.cuh" - -#define CUDA_ALIBI_BLOCK_SIZE 32 - -void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 7c486f48..ac5d6672 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -23,6 +23,10 @@ static __global__ void flash_attn_vec_ext_f16( float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, const int ne00, const int ne01, const int ne02, @@ -58,6 +62,18 @@ static __global__ void flash_attn_vec_ext_f16( const int stride_KV = nb11 / sizeof(half); const int stride_KV2 = nb11 / sizeof(half2); + half slopeh = __float2half(1.0f); + + // ALiBi + if (max_bias > 0.0f) { + const int h = blockIdx.y; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slopeh = __float2half(powf(base, exph)); + } + static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); constexpr int nwarps = D / WARP_SIZE; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; @@ -141,7 +157,7 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { sum2[j] = warp_reduce_sum(sum2[j]); half sum = __low2half(sum2[j]) + __high2half(sum2[j]); - sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); + sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); if (ncols == 1) { kqmax_new = ggml_cuda_hmax(kqmax_new, sum); @@ -249,6 +265,10 @@ static __global__ void flash_attn_ext_f16( float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, const int ne00, const int ne01, const int ne02, @@ -305,6 +325,20 @@ static __global__ void flash_attn_ext_f16( const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); + half slopeh = __float2half(1.0f); + half2 slope2 = make_half2(1.0f, 1.0f); + + // ALiBi + if (max_bias > 0.0f) { + const int h = blockIdx.y; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slopeh = __float2half(powf(base, exph)); + slope2 = make_half2(slopeh, slopeh); + } + frag_b Q_b[D/16][ncols/frag_n]; // A single buffer for temporarily holding tiles of KQ and VKQ parts: @@ -421,7 +455,7 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); } KQ_max_new = warp_reduce_max(KQ_max_new); @@ -464,7 +498,7 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); } KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); @@ -710,8 +744,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_ const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); const int shmem = 0; - float scale; - memcpy(&scale, KQV->op_params, sizeof(float)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks> <<<blocks_num, block_dim, shmem, main_stream>>> ( @@ -720,7 +763,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_ (const char *) V->data, mask ? ((const char *) mask->data) : nullptr, parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, + scale, max_bias, m0, m1, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, @@ -761,8 +804,17 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const int shmem = 0; - float scale; - memcpy(&scale, KQV->op_params, sizeof(float)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t> <<<blocks_num, block_dim, shmem, main_stream>>> ( @@ -771,7 +823,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K (const char *) V->data, mask ? ((const char *) mask->data) : nullptr, (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, + scale, max_bias, m0, m1, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, @@ -837,7 +889,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - const int32_t precision = KQV->op_params[1]; + const int32_t precision = KQV->op_params[2]; if (!fp16_mma_available(cc)) { GGML_ASSERT(precision == GGML_PREC_DEFAULT); diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu index 6ed22599..ca85285a 100644 --- a/ggml-cuda/softmax.cu +++ b/ggml-cuda/softmax.cu @@ -11,7 +11,7 @@ __device__ float __forceinline__ t2f32<half>(half val) { } template <bool vals_smem, int ncols_template, int block_size_template, typename T> -static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -23,16 +23,16 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p const int warp_id = threadIdx.x / WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE; - float slope = 0.0f; + float slope = 1.0f; // ALiBi if (max_bias > 0.0f) { const int h = rowx/nrows_y; // head index const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - slope = powf(base, exp); + slope = powf(base, exph); } extern __shared__ float data_soft_max_f32[]; @@ -53,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p const int64_t ix = (int64_t)rowx*ncols + col; const int64_t iy = (int64_t)rowy*ncols + col; - const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f); + const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -125,7 +125,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p } template<typename T> -static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -133,8 +133,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); - const uint32_t n_head_kv = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); @@ -142,43 +142,42 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { switch (ncols_x) { case 32: - soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 64: - soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 128: - soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 256: - soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 512: - soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 1024: - soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 2048: - soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 4096: - soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; default: - soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; } } else { const size_t shmem_low = WARP_SIZE*sizeof(float); - soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); } } void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const ggml_tensor * src2 = dst->src[2]; const float * src0_d = (const float *)src0->data; const void * src1_d = src1 ? (const void *)src1->data : nullptr; @@ -190,7 +189,6 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional - GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -202,26 +200,15 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - // positions tensor - void * src2_d = nullptr; - - const bool use_src2 = src2 != nullptr; - - if (use_src2) { - src2_d = (void *)src2->data; - } - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); if (use_f16) { const half * src1_dd = (const half *)src1_d; - const half * src2_dd = (const half *)src2_d; - soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } else { const float * src1_dd = (const float *)src1_d; - const float * src2_dd = (const float *)src2_d; - soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } } |