diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-09-28 17:41:21 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-28 17:41:21 +0300 |
commit | 7abcc6cc0b0a48b780bb0877e4720c46a7e3c255 (patch) | |
tree | 0fbe1e7c3d02462ba3972f9be88a39ab9a3c0c30 | |
parent | 737514fd814d944f8ce965620293a16e5e8a285d (diff) |
CUDA non-contiguous RoPE (#66)
In this way we can avoid the Q, K, V copies being made
after multiplication with the QKV tensor in, e.g., Phi-3.5-mini.
This results in a 6-7% speedup of PP-512(Phi-3.5-mini)
on CUDA (RTX-4080)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml-cuda.cu | 5 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/rope.cu | 247 | ||||
-rw-r--r-- | src/llama.cpp | 11 |
3 files changed, 231 insertions, 32 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 966c91c0..3d24cc6f 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2924,9 +2924,10 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_CAP_MAX: - return true; case GGML_OP_ROPE: - return ggml_is_contiguous(op->src[0]); + return true; + //case GGML_OP_ROPE: + // return ggml_is_contiguous(op->src[0]); case GGML_OP_IM2COL: case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 99ec1dd9..42ca72af 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -69,6 +69,49 @@ static __global__ void rope_norm( } template<typename T, bool has_ff> +static __global__ void rope_norm_nc( + const T * x, T * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (i0 >= ne0) { + return; + } + + const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int j2 = row/ne1; + const int j1 = row%ne1; + const T * xx = x + j1*nb1 + j2*nb2; + + if (i0 >= n_dims) { + const int i = row*ne0 + i0; + + dst[i + 0] = xx[i0 + 0]; + dst[i + 1] = xx[i0 + 1]; + + return; + } + + const int i = row*ne0 + i0; + const int i2 = row/p_delta_rows; + + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + const float x0 = xx[i0 + 0]; + const float x1 = xx[i0 + 1]; + + dst[i + 0] = x0*cos_theta - x1*sin_theta; + dst[i + 1] = x0*sin_theta + x1*cos_theta; +} + +template<typename T, bool has_ff> static __global__ void rope_neox( const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { @@ -108,6 +151,49 @@ static __global__ void rope_neox( dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; } +template<typename T, bool has_ff> +static __global__ void rope_neox_nc( + const T * x, T * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (i0 >= ne0) { + return; + } + + const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int j2 = row/ne1; + const int j1 = row%ne1; + const T * xx = x + j1*nb1 + j2*nb2; + + if (i0 >= n_dims) { + const int i = row*ne0 + i0; + + dst[i + 0] = xx[i0 + 0]; + dst[i + 1] = xx[i0 + 1]; + + return; + } + + const int i = row*ne0 + i0/2; + const int i2 = row/p_delta_rows; + + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + const float x0 = xx[i0/2 + 0]; + const float x1 = xx[i0/2 + n_dims/2]; + + dst[i + 0] = x0*cos_theta - x1*sin_theta; + dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; +} + template<typename T> static void rope_norm_cuda( const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, @@ -133,6 +219,30 @@ static void rope_norm_cuda( } template<typename T> +static void rope_norm_nc_cuda( + const T * x, T * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + GGML_ASSERT(ne0 % 2 == 0); + const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, 1); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + if (freq_factors == nullptr) { + rope_norm_nc<T, false><<<block_nums, block_dims, 0, stream>>>( + x, dst, ne0, ne1, nb1, nb2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors + ); + } else { + rope_norm_nc<T, true><<<block_nums, block_dims, 0, stream>>>( + x, dst, ne0, ne1, nb1, nb2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors + ); + } +} + +template<typename T> static void rope_neox_cuda( const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { @@ -156,6 +266,30 @@ static void rope_neox_cuda( } } +template<typename T> +static void rope_neox_nc_cuda( + const T * x, T * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + GGML_ASSERT(ne0 % 2 == 0); + const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, 1); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + if (freq_factors == nullptr) { + rope_neox_nc<T, false><<<block_nums, block_dims, 0, stream>>>( + x, dst, ne0, ne1, nb1, nb2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors + ); + } else { + rope_neox_nc<T, true><<<block_nums, block_dims, 0, stream>>>( + x, dst, ne0, ne1, nb1, nb2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors + ); + } +} + static void rope_norm_cuda_f16( const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { @@ -170,6 +304,20 @@ static void rope_norm_cuda_f32( rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } +static void rope_norm_cuda_nc_f16( + const half * x, half * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + + rope_norm_nc_cuda<half>(x, dst, ne0, ne1, nb1, nb2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); +} + +static void rope_norm_cuda_nc_f32( + const float * x, float * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + + rope_norm_nc_cuda<float>(x, dst, ne0, ne1, nb1, nb2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); +} + static void rope_neox_cuda_f16( const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { @@ -185,6 +333,20 @@ static void rope_neox_cuda_f32( rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } +static void rope_neox_cuda_nc_f16( + const half * x, half * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + + rope_neox_nc_cuda<half>(x, dst, ne0, ne1, nb1, nb2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); +} + +static void rope_neox_cuda_nc_f32( + const float * x, float * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + + rope_neox_nc_cuda<float>(x, dst, ne0, ne1, nb1, nb2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); +} + void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -196,7 +358,8 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(ggml_is_contiguous(src0)); + const bool is_contiguous = ggml_is_contiguous(src0); + //GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == dst->type); @@ -239,33 +402,69 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v); // compute - if (is_neox) { - if (src0->type == GGML_TYPE_F32) { - rope_neox_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); - } else if (src0->type == GGML_TYPE_F16) { - rope_neox_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); + if (is_contiguous) { + if (is_neox) { + if (src0->type == GGML_TYPE_F32) { + rope_neox_cuda_f32( + (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream + ); + } else if (src0->type == GGML_TYPE_F16) { + rope_neox_cuda_f16( + (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream + ); + } else { + GGML_ABORT("fatal error"); + } } else { - GGML_ABORT("fatal error"); + if (src0->type == GGML_TYPE_F32) { + rope_norm_cuda_f32( + (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream + ); + } else if (src0->type == GGML_TYPE_F16) { + rope_norm_cuda_f16( + (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream + ); + } else { + GGML_ABORT("fatal error"); + } } } else { - if (src0->type == GGML_TYPE_F32) { - rope_norm_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); - } else if (src0->type == GGML_TYPE_F16) { - rope_norm_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); + if (is_neox) { + if (src0->type == GGML_TYPE_F32) { + rope_neox_cuda_nc_f32( + (const float *)src0_d, (float *)dst_d, ne00, ne01, src0->nb[1]/sizeof(float), src0->nb[2]/sizeof(float), + n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream + ); + } else if (src0->type == GGML_TYPE_F16) { + rope_neox_cuda_nc_f16( + (const half *)src0_d, (half *)dst_d, ne00, ne01, src0->nb[1]/sizeof(half), src0->nb[2]/sizeof(half), + n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream + ); + } else { + GGML_ABORT("fatal error"); + } } else { - GGML_ABORT("fatal error"); + if (src0->type == GGML_TYPE_F32) { + rope_norm_cuda_nc_f32( + (const float *)src0_d, (float *)dst_d, ne00, ne01, src0->nb[1]/sizeof(float), src0->nb[2]/sizeof(float), + n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream + ); + } else if (src0->type == GGML_TYPE_F16) { + rope_norm_cuda_nc_f16( + (const half *)src0_d, (half *)dst_d, ne00, ne01, src0->nb[1]/sizeof(half), src0->nb[2]/sizeof(half), + n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream + ); + } else { + GGML_ABORT("fatal error"); + } } } } diff --git a/src/llama.cpp b/src/llama.cpp index d52590a6..94a939d8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -10919,23 +10919,22 @@ struct llm_build_context { cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output); cb(cur, "wqkv", il); - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd))); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd)); + Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); } else { Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow |