summaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu153
1 files changed, 111 insertions, 42 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 4e6e7cd9..12ee10e3 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4493,11 +4493,41 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
cpy_1(cx + x_offset, cdst + dst_offset);
}
-// rope == RoPE == rotary positional embedding
+static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
+ const float y = (i0 / 2 - low) / max(0.001f, high - low);
+ return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+struct rope_corr_dims {
+ float v[4];
+};
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static __device__ void rope_yarn(
+ float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
+ float * cos_theta, float * sin_theta
+) {
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = freq_scale * theta_extrap;
+ float theta = theta_interp;
+ if (ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+ }
+ *cos_theta = cosf(theta) * mscale;
+ *sin_theta = sinf(theta) * mscale;
+}
+
+// rope == RoPE == rotary positional embedding
template<typename T, bool has_pos>
-static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
- const int p_delta_rows, const float theta_scale) {
+static __global__ void rope(
+ const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims
+) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (col >= ncols) {
@@ -4509,10 +4539,10 @@ static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t
const int i2 = row/p_delta_rows;
const int p = has_pos ? pos[i2] : 0;
- const float p0 = p*freq_scale;
- const float theta = p0*powf(theta_scale, col/2);
- const float sin_theta = sinf(theta);
- const float cos_theta = cosf(theta);
+ const float theta_base = p*powf(freq_base, -col/ncols);
+
+ float cos_theta, sin_theta;
+ rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + 1];
@@ -4522,8 +4552,10 @@ static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t
}
template<typename T, bool has_pos>
-static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
- const int p_delta_rows, const float theta_scale) {
+static __global__ void rope_neox(
+ const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims
+) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (col >= ncols) {
@@ -4534,11 +4566,14 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in
const int i = row*ncols + col/2;
const int i2 = row/p_delta_rows;
+ // simplified from `(row * ncols + col) * (-1 / ncols)`
+ const float cur_rot = -col/ncols - row;
+
const int p = has_pos ? pos[i2] : 0;
- const float p0 = p*freq_scale;
- const float theta = p0*powf(theta_scale, col/2);
- const float sin_theta = sinf(theta);
- const float cos_theta = cosf(theta);
+ const float theta_base = p*powf(freq_base, cur_rot);
+
+ float cos_theta, sin_theta;
+ rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + ncols/2];
@@ -4547,8 +4582,10 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
}
-static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
- const int p_delta_rows, const float theta_scale, const int n_ctx) {
+static __global__ void rope_glm_f32(
+ const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+ int n_ctx
+) {
const int col = blockDim.x*blockIdx.x + threadIdx.x;
const int half_n_dims = ncols/4;
@@ -4560,7 +4597,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
const int i = row*ncols + col;
const int i2 = row/p_delta_rows;
- const float col_theta_scale = powf(theta_scale, col);
+ const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
// FIXME: this is likely wrong
const int p = pos != nullptr ? pos[i2] : 0;
@@ -5584,40 +5621,54 @@ static void clamp_f32_cuda(const float * x, float * dst, const float min, const
}
template<typename T>
-static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
- const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+static void rope_cuda(
+ const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1);
if (pos == nullptr) {
- rope<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+ rope<T, false><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+ );
} else {
- rope<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+ rope<T, true><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+ );
}
}
template<typename T>
-static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
- const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+static void rope_neox_cuda(
+ const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1);
if (pos == nullptr) {
- rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+ rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+ );
} else {
- rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+ rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+ );
}
}
-static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
- const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
+static void rope_glm_f32_cuda(
+ const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, int n_ctx, cudaStream_t stream
+) {
GGML_ASSERT(ncols % 4 == 0);
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
const dim3 block_nums(num_blocks_x, nrows, 1);
- rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx);
+ rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
}
static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
@@ -6477,17 +6528,20 @@ inline void ggml_cuda_op_rope(
const int64_t ne2 = dst->ne[2];
const int64_t nrows = ggml_nrows(src0);
- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_dims = ((int32_t *) dst->op_params)[1];
- const int mode = ((int32_t *) dst->op_params)[2];
- const int n_ctx = ((int32_t *) dst->op_params)[3];
- // RoPE alteration for extended context
-
- float freq_base, freq_scale;
- memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
- memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+ //const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ const int n_ctx = ((int32_t *) dst->op_params)[3];
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
+ // RoPE alteration for extended context
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
const int32_t * pos = nullptr;
if ((mode & 1) == 0) {
@@ -6499,24 +6553,39 @@ inline void ggml_cuda_op_rope(
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
+ rope_corr_dims corr_dims;
+ ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
+
// compute
if (is_glm) {
GGML_ASSERT(false);
- rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
+ rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
} else if (is_neox) {
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
if (src0->type == GGML_TYPE_F32) {
- rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+ rope_neox_cuda(
+ (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, main_stream
+ );
} else if (src0->type == GGML_TYPE_F16) {
- rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+ rope_neox_cuda(
+ (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, main_stream
+ );
} else {
GGML_ASSERT(false);
}
} else {
if (src0->type == GGML_TYPE_F32) {
- rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+ rope_cuda(
+ (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, main_stream
+ );
} else if (src0->type == GGML_TYPE_F16) {
- rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+ rope_cuda(
+ (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, main_stream
+ );
} else {
GGML_ASSERT(false);
}