author     Kawrakow <iwankawrakow@gmail.com>  2024-09-28 17:41:21 +0300
committer  GitHub <noreply@github.com>        2024-09-28 17:41:21 +0300
commit     7abcc6cc0b0a48b780bb0877e4720c46a7e3c255 (patch)
tree       0fbe1e7c3d02462ba3972f9be88a39ab9a3c0c30
parent     737514fd814d944f8ce965620293a16e5e8a285d (diff)
CUDA non-contiguous RoPE (#66)

This lets us avoid the Q, K, V copies that were made after the multiplication
with the fused QKV tensor in, e.g., Phi-3.5-mini, and results in a 6-7% speedup
of PP-512(Phi-3.5-mini) on CUDA (RTX-4080).

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/src/ggml-cuda.cu         5
-rw-r--r--  ggml/src/ggml-cuda/rope.cu  247
-rw-r--r--  src/llama.cpp                11

3 files changed, 231 insertions(+), 32 deletions(-)
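
The commit message describes the graph-level pattern this change enables. The
following is a minimal sketch of that pattern, not code from the commit: it
assumes a fused wqkv projection as in Phi-3.5-mini, and the names (cur,
n_embd_head, n_head, n_tokens, and the RoPE parameters) follow the llama.cpp
hunk at the end of this diff.

    // Before this commit, the CUDA backend required contiguous RoPE input,
    // so each view into the fused QKV output had to be copied first:
    //
    //   Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0));
    //
    // With non-contiguous RoPE, a strided 3D view can be passed to
    // ggml_rope_ext directly: consecutive heads are n_embd_head floats apart
    // within a row, while consecutive tokens are a full QKV row (cur->nb[1]
    // bytes) apart, so the view as a whole is non-contiguous.
    struct ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens,
                                             n_embd_head*sizeof(float), cur->nb[1], 0);
    Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
                         freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);

Skipping the two ggml_cont copies is where the measured PP-512 speedup comes
from; the K view is handled the same way, as the llama.cpp hunk below shows.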
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 966c91c0..3d24cc6f 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2924,9 +2924,10 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_CAP_MAX:
- return true;
case GGML_OP_ROPE:
- return ggml_is_contiguous(op->src[0]);
+ return true;
+ //case GGML_OP_ROPE:
+ // return ggml_is_contiguous(op->src[0]);
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu
index 99ec1dd9..42ca72af 100644
--- a/ggml/src/ggml-cuda/rope.cu
+++ b/ggml/src/ggml-cuda/rope.cu
@@ -69,6 +69,49 @@ static __global__ void rope_norm(
}
template<typename T, bool has_ff>
+static __global__ void rope_norm_nc(
+ const T * x, T * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+ const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+ if (i0 >= ne0) {
+ return;
+ }
+
+ const int row = blockDim.x*blockIdx.x + threadIdx.x;
+ const int j2 = row/ne1;
+ const int j1 = row%ne1;
+ const T * xx = x + j1*nb1 + j2*nb2;
+
+ if (i0 >= n_dims) {
+ const int i = row*ne0 + i0;
+
+ dst[i + 0] = xx[i0 + 0];
+ dst[i + 1] = xx[i0 + 1];
+
+ return;
+ }
+
+ const int i = row*ne0 + i0;
+ const int i2 = row/p_delta_rows;
+
+ const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+ float cos_theta;
+ float sin_theta;
+
+ rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ const float x0 = xx[i0 + 0];
+ const float x1 = xx[i0 + 1];
+
+ dst[i + 0] = x0*cos_theta - x1*sin_theta;
+ dst[i + 1] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T, bool has_ff>
static __global__ void rope_neox(
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
@@ -108,6 +151,49 @@ static __global__ void rope_neox(
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
+template<typename T, bool has_ff>
+static __global__ void rope_neox_nc(
+ const T * x, T * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+ const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+ if (i0 >= ne0) {
+ return;
+ }
+
+ const int row = blockDim.x*blockIdx.x + threadIdx.x;
+ const int j2 = row/ne1;
+ const int j1 = row%ne1;
+ const T * xx = x + j1*nb1 + j2*nb2;
+
+ if (i0 >= n_dims) {
+ const int i = row*ne0 + i0;
+
+ dst[i + 0] = xx[i0 + 0];
+ dst[i + 1] = xx[i0 + 1];
+
+ return;
+ }
+
+ const int i = row*ne0 + i0/2;
+ const int i2 = row/p_delta_rows;
+
+ const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+ float cos_theta;
+ float sin_theta;
+
+ rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ const float x0 = xx[i0/2 + 0];
+ const float x1 = xx[i0/2 + n_dims/2];
+
+ dst[i + 0] = x0*cos_theta - x1*sin_theta;
+ dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+}
+
template<typename T>
static void rope_norm_cuda(
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
@@ -133,6 +219,30 @@ static void rope_norm_cuda(
}
template<typename T>
+static void rope_norm_nc_cuda(
+ const T * x, T * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+ GGML_ASSERT(ne0 % 2 == 0);
+ const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+ const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+ const dim3 block_nums(nr, n_blocks_x, 1);
+
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+ if (freq_factors == nullptr) {
+ rope_norm_nc<T, false><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne0, ne1, nb1, nb2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, freq_factors
+ );
+ } else {
+ rope_norm_nc<T, true><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne0, ne1, nb1, nb2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, freq_factors
+ );
+ }
+}
+
+template<typename T>
static void rope_neox_cuda(
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
@@ -156,6 +266,30 @@ static void rope_neox_cuda(
}
}
+template<typename T>
+static void rope_neox_nc_cuda(
+ const T * x, T * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+ GGML_ASSERT(ne0 % 2 == 0);
+ const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+ const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+ const dim3 block_nums(nr, n_blocks_x, 1);
+
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+ if (freq_factors == nullptr) {
+ rope_neox_nc<T, false><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne0, ne1, nb1, nb2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, freq_factors
+ );
+ } else {
+ rope_neox_nc<T, true><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne0, ne1, nb1, nb2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, freq_factors
+ );
+ }
+}
+
static void rope_norm_cuda_f16(
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
@@ -170,6 +304,20 @@ static void rope_norm_cuda_f32(
rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
+static void rope_norm_cuda_nc_f16(
+ const half * x, half * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+ rope_norm_nc_cuda<half>(x, dst, ne0, ne1, nb1, nb2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+static void rope_norm_cuda_nc_f32(
+ const float * x, float * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+ rope_norm_nc_cuda<float>(x, dst, ne0, ne1, nb1, nb2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
static void rope_neox_cuda_f16(
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
@@ -185,6 +333,20 @@ static void rope_neox_cuda_f32(
rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
+static void rope_neox_cuda_nc_f16(
+ const half * x, half * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+ rope_neox_nc_cuda<half>(x, dst, ne0, ne1, nb1, nb2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+static void rope_neox_cuda_nc_f32(
+ const float * x, float * dst, int ne0, int ne1, int nb1, int nb2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+ rope_neox_nc_cuda<float>(x, dst, ne0, ne1, nb1, nb2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
@@ -196,7 +358,8 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
- GGML_ASSERT(ggml_is_contiguous(src0));
+ const bool is_contiguous = ggml_is_contiguous(src0);
+ //GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
@@ -239,33 +402,69 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
// compute
- if (is_neox) {
- if (src0->type == GGML_TYPE_F32) {
- rope_neox_cuda_f32(
- (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
- attn_factor, corr_dims, freq_factors, stream
- );
- } else if (src0->type == GGML_TYPE_F16) {
- rope_neox_cuda_f16(
- (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
- attn_factor, corr_dims, freq_factors, stream
- );
+ if (is_contiguous) {
+ if (is_neox) {
+ if (src0->type == GGML_TYPE_F32) {
+ rope_neox_cuda_f32(
+ (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, stream
+ );
+ } else if (src0->type == GGML_TYPE_F16) {
+ rope_neox_cuda_f16(
+ (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, stream
+ );
+ } else {
+ GGML_ABORT("fatal error");
+ }
} else {
- GGML_ABORT("fatal error");
+ if (src0->type == GGML_TYPE_F32) {
+ rope_norm_cuda_f32(
+ (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, stream
+ );
+ } else if (src0->type == GGML_TYPE_F16) {
+ rope_norm_cuda_f16(
+ (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, stream
+ );
+ } else {
+ GGML_ABORT("fatal error");
+ }
}
} else {
- if (src0->type == GGML_TYPE_F32) {
- rope_norm_cuda_f32(
- (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
- attn_factor, corr_dims, freq_factors, stream
- );
- } else if (src0->type == GGML_TYPE_F16) {
- rope_norm_cuda_f16(
- (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
- attn_factor, corr_dims, freq_factors, stream
- );
+ if (is_neox) {
+ if (src0->type == GGML_TYPE_F32) {
+ rope_neox_cuda_nc_f32(
+ (const float *)src0_d, (float *)dst_d, ne00, ne01, src0->nb[1]/sizeof(float), src0->nb[2]/sizeof(float),
+ n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, stream
+ );
+ } else if (src0->type == GGML_TYPE_F16) {
+ rope_neox_cuda_nc_f16(
+ (const half *)src0_d, (half *)dst_d, ne00, ne01, src0->nb[1]/sizeof(half), src0->nb[2]/sizeof(half),
+ n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, stream
+ );
+ } else {
+ GGML_ABORT("fatal error");
+ }
} else {
- GGML_ABORT("fatal error");
+ if (src0->type == GGML_TYPE_F32) {
+ rope_norm_cuda_nc_f32(
+ (const float *)src0_d, (float *)dst_d, ne00, ne01, src0->nb[1]/sizeof(float), src0->nb[2]/sizeof(float),
+ n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, stream
+ );
+ } else if (src0->type == GGML_TYPE_F16) {
+ rope_norm_cuda_nc_f16(
+ (const half *)src0_d, (half *)dst_d, ne00, ne01, src0->nb[1]/sizeof(half), src0->nb[2]/sizeof(half),
+ n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, stream
+ );
+ } else {
+ GGML_ABORT("fatal error");
+ }
}
}
}
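
The _nc kernels above differ from their contiguous counterparts only in how
the source row is located: instead of indexing x at row*ne0, they receive the
strides nb1 and nb2 in units of elements (the caller divides src0->nb[1] and
src0->nb[2] by sizeof(T)) and compute a per-row base pointer xx, while the
destination keeps the plain contiguous row*ne0 + i0 layout. A hypothetical
host-side mirror of that addressing, for illustration only:

    // Illustration only (not part of the commit): locate a source row the
    // way rope_norm_nc/rope_neox_nc do. j1 indexes within dimension 1, j2
    // along dimension 2; nb1 and nb2 are strides in elements. For a
    // contiguous tensor, nb1 == ne0 and nb2 == ne0*ne1, and the result
    // reduces to x + row*ne0, i.e. the contiguous kernels' indexing.
    template <typename T>
    static const T * nc_row_ptr(const T * x, int row, int ne1, int nb1, int nb2) {
        const int j2 = row / ne1;   // index along dimension 2
        const int j1 = row % ne1;   // index along dimension 1
        return x + j1*nb1 + j2*nb2; // matches `xx` in the kernels above
    }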
diff --git a/src/llama.cpp b/src/llama.cpp
index d52590a6..94a939d8 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -10919,23 +10919,22 @@ struct llm_build_context {
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output);
cb(cur, "wqkv", il);
- Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
- Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
- Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
+ Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd));
+ Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd));
+ Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
}
else {
Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow