author    liuwei-git <14815172+liuwei-git@users.noreply.github.com>  2024-05-22 04:28:32 +0800
committer GitHub <noreply@github.com>  2024-05-21 23:28:32 +0300
commit    201cc11afa0a1950e1f632390b2ac6c937a0d8f0 (patch)
tree      440fb7ecd80b48772a955a80855db29677d172a2 /ggml.h
parent    6369bf04336ab60e5c892dd77a3246df91015147 (diff)
llama : add phi3 128K model support (#7225)
* add phi3 128k support in convert-hf-to-gguf
* add phi3 128k support in cuda
* address build warnings on llama.cpp
* adjust index value in cuda long rope freq factors
* add long rope support in ggml cpu backend
* make freq factors only depend on ctx size
* remove unused rope scaling type 'su' from gguf converter
* fix lint warnings on convert-hf-to-gguf.py
* set to the short freq factor when context size is smaller than trained context size
* add one line of comments
* metal : support rope freq_factors
* ggml : update ggml_rope_ext API to support freq. factors
* backends : add dev messages to support rope freq. factors
* minor : style
* tests : update to use new rope API
* backends : fix pragma semicolons
* minor : cleanup
* llama : move rope factors from KV header to tensors
* llama : remove tmp assert
* cuda : fix compile warning
* convert : read/write n_head_kv
* llama : fix uninitialized tensors

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
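The headline API change: ggml_rope_custom and ggml_rope_custom_inplace are deprecated in favor of ggml_rope_ext / ggml_rope_ext_inplace, which take one extra tensor c holding optional per-dimension frequency factors. A minimal sketch of a caller against the signature in this header; the wrapper name, tensor shapes, and all numeric values are assumptions for illustration, not taken from the patch:

#include "ggml.h"

// Sketch only: builds a RoPE node with optional frequency factors.
// Passing factors == NULL keeps the pre-patch ggml_rope_custom behavior.
static struct ggml_tensor * build_rope(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,        // e.g. [head_dim, n_head, n_tokens]
        struct ggml_tensor  * positions,  // int32 vector of size cur->ne[2]
        struct ggml_tensor  * factors) {  // freq factors, or NULL
    return ggml_rope_ext(ctx, cur, positions, factors,
            /*n_dims     =*/ (int) cur->ne[0], // rotate the whole head dim
            /*mode       =*/ 0,                // standard (non-NeoX) RoPE
            /*n_ctx      =*/ 0,
            /*n_orig_ctx =*/ 4096,             // assumed trained context
            /*freq_base  =*/ 10000.0f,
            /*freq_scale =*/ 1.0f,
            /*ext_factor =*/ 0.0f,
            /*attn_factor=*/ 1.0f,
            /*beta_fast  =*/ 32.0f,
            /*beta_slow  =*/ 1.0f);
}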
Diffstat (limited to 'ggml.h')
-rw-r--r--  ggml.h  45
1 file changed, 36 insertions(+), 9 deletions(-)
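The bullet "set to the short freq factor when context size is smaller than trained context size" amounts to a simple selection rule at graph-build time. A hedged sketch of that rule; every identifier here (rope_short, rope_long, n_ctx, n_ctx_train) is an assumed name, not one taken from this patch:

// Assumed names throughout; illustrates the rule from the commit message:
// stay on the short frequency factors inside the trained window, switch
// to the long ones once the configured context exceeds it.
struct ggml_tensor * freq_factors =
        (n_ctx > n_ctx_train) ? rope_long : rope_short;

Either tensor (or NULL) is then what gets passed as the c argument of ggml_rope_ext, as sketched above.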
diff --git a/ggml.h b/ggml.h
index 77475710..35ac9110 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1465,6 +1465,7 @@ extern "C" {
// if mode & 4 == 1, ChatGLM style
//
// b is an int32 vector with size a->ne[2], it contains the positions
+ // c is an optional vector of frequency factors (e.g. for phi3-128k)
GGML_API struct ggml_tensor * ggml_rope(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -1483,10 +1484,11 @@ extern "C" {
int n_ctx);
// custom RoPE
- GGML_API struct ggml_tensor * ggml_rope_custom(
+ GGML_API struct ggml_tensor * ggml_rope_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
+ struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
@@ -1499,10 +1501,11 @@ extern "C" {
float beta_slow);
// in-place, returns view(a)
- GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
+ GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
+ struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
@@ -1514,18 +1517,41 @@ extern "C" {
float beta_fast,
float beta_slow);
- // compute correction dims for YaRN RoPE scaling
- GGML_CALL void ggml_rope_yarn_corr_dims(
- int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int n_dims,
+ int mode,
+ int n_ctx,
+ int n_orig_ctx,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow),
+ "use ggml_rope_ext instead");
- // xPos RoPE, in-place, returns view(a)
- GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int n_dims,
- float base,
- bool down);
+ int mode,
+ int n_ctx,
+ int n_orig_ctx,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow),
+ "use ggml_rope_ext_inplace instead");
+
+ // compute correction dims for YaRN RoPE scaling
+ GGML_CALL void ggml_rope_yarn_corr_dims(
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
// rotary position embedding backward, i.e. compute dx from dy
// a - dy
@@ -1533,6 +1559,7 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
+ struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,