path: root/ggml-cuda
author    liuwei-git <14815172+liuwei-git@users.noreply.github.com>  2024-05-22 04:28:32 +0800
committer GitHub <noreply@github.com>                                2024-05-21 23:28:32 +0300
commit    201cc11afa0a1950e1f632390b2ac6c937a0d8f0 (patch)
tree      440fb7ecd80b48772a955a80855db29677d172a2 /ggml-cuda
parent    6369bf04336ab60e5c892dd77a3246df91015147 (diff)
llama : add phi3 128K model support (#7225)
* add phi3 128k support in convert-hf-to-gguf
* add phi3 128k support in cuda
* address build warnings on llama.cpp
* adjust index value in cuda long rope freq factors
* add long rope support in ggml cpu backend
* make freq factors only depend on ctx size
* remove unused rope scaling type 'su' from gguf converter
* fix lint warnings on convert-hf-to-gguf.py
* use the short freq factor when the context size is smaller than the trained context size
* add one line of comments
* metal : support rope freq_factors
* ggml : update ggml_rope_ext API to support freq. factors
* backends : add dev messages to support rope freq. factors
* minor : style
* tests : update to use new rope API
* backends : fix pragma semicolons
* minor : cleanup
* llama : move rope factors from KV header to tensors
* llama : remove tmp assert
* cuda : fix compile warning
* convert : read/write n_head_kv
* llama : fix uninitialized tensors

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'ggml-cuda')
-rw-r--r--  ggml-cuda/rope.cu  72
1 file changed, 48 insertions, 24 deletions
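
The core of the patch is in rope_neox: the NeoX rotation angle is now divided by a per-dimension frequency factor when one is provided. A minimal CPU-side sketch of that angle math, not part of the patch and using hypothetical freq_factors values, for reference against the diff below:

// Minimal CPU sketch of the angle math introduced by this patch (not part of the
// patch itself). theta_scale follows the standard rope derivation
// freq_base^(-2/n_dims); the freq_factors values are hypothetical placeholders.
#include <math.h>
#include <stdio.h>

int main(void) {
    const int   n_dims     = 8;        // rotary dimensions (must be even)
    const float freq_base  = 10000.0f; // rope base frequency
    const float freq_scale = 1.0f;     // linear position scaling
    const int   p          = 42;       // token position

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    // one factor per dimension pair (placeholder values only)
    const float freq_factors[4] = {1.0f, 1.0f, 2.0f, 4.0f};

    for (int col = 0; col < n_dims; col += 2) {
        // before this patch: theta_base = p*freq_scale*theta_scale^(col/2)
        const float theta_old = p*freq_scale*powf(theta_scale, col/2.0f);
        // with this patch: additionally divide by the factor for this dimension pair
        const float theta_new = theta_old/freq_factors[col/2];
        printf("col %d: theta_base %.5f -> %.5f\n", col, theta_old, theta_new);
    }
    return 0;
}
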
diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu
index 4b0d2e5a..4a558f4b 100644
--- a/ggml-cuda/rope.cu
+++ b/ggml-cuda/rope.cu
@@ -58,10 +58,10 @@ static __global__ void rope(
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}
-template<typename T, bool has_pos>
+template<typename T, bool has_pos, bool has_freq_facs>
static __global__ void rope_neox(
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@@ -88,7 +88,9 @@ static __global__ void rope_neox(
float cur_rot = inv_ndims * ic - ib;
const int p = has_pos ? pos[i2] : 0;
- const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
+ const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
+
+ const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
@@ -164,7 +166,7 @@ static void rope_cuda(
template<typename T>
static void rope_neox_cuda(
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
@@ -175,15 +177,29 @@ static void rope_neox_cuda(
const float inv_ndims = -1.0f / n_dims;
if (pos == nullptr) {
- rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
- x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
- theta_scale, inv_ndims
- );
+ if (freq_factors == nullptr) {
+ rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, inv_ndims, freq_factors
+ );
+ } else {
+ rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, inv_ndims, freq_factors
+ );
+ }
} else {
- rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
- x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
- theta_scale, inv_ndims
- );
+ if (freq_factors == nullptr) {
+ rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, inv_ndims, freq_factors
+ );
+ } else {
+ rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, inv_ndims, freq_factors
+ );
+ }
}
}
@@ -214,24 +230,27 @@ static void rope_cuda_f32(
static void rope_neox_cuda_f16(
const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
- rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+ rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_neox_cuda_f32(
const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {
- rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+ rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
+
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
@@ -241,7 +260,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
- const int64_t ne2 = dst->ne[2];
const int64_t nrows = ggml_nrows(src0);
//const int n_past = ((int32_t *) dst->op_params)[0];
@@ -259,16 +277,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+ const float * freq_factors = nullptr;
const int32_t * pos = nullptr;
- if ((mode & 1) == 0) {
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
- GGML_ASSERT(src1->ne[0] == ne2);
- pos = (const int32_t *) src1_d;
- }
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
+ if (is_neox) {
+ pos = (const int32_t *) src1_d;
+
+ if (src2 != nullptr) {
+ freq_factors = (const float *) src2->data;
+ }
+ } else {
+ GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
+ }
+
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
@@ -280,12 +304,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
- attn_factor, corr_dims, stream
+ attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
- attn_factor, corr_dims, stream
+ attn_factor, corr_dims, freq_factors, stream
);
} else {
GGML_ASSERT(false);
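
For context, the commit message states that the short frequency factors are used when the context size is smaller than the trained context size. A hedged sketch of that selection rule (the function and parameter names are hypothetical and do not appear in this patch):

// Hedged sketch of the selection described in the commit message; names are
// illustrative only.
static const float * select_rope_freq_factors(
        const float * short_factors, const float * long_factors,
        int n_ctx, int n_ctx_train) {
    // per the commit message, the choice depends only on the context size
    return n_ctx < n_ctx_train ? short_factors : long_factors;
}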