summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-05-22 11:01:35 +0300
committerGitHub <noreply@github.com>2024-05-22 11:01:35 +0300
commit3e5faa85032ec3106a2ad831bf412be9ff139f47 (patch)
treedc85a6e015eecb4c771bffd6d3c4202459edaa9a
parent201cc11afa0a1950e1f632390b2ac6c937a0d8f0 (diff)
cuda : fix rope + add tests (#7452)
* cuda : fix rope pos data ggml-ci * ggml : drop mode & 1 == 1 support for ggml_rope ggml-ci * ggml : support freq_factors for f16 rope (CPU) ggml-ci * tests : add rope tests using frequency factors ggml-ci
-rw-r--r--ggml-cuda/rope.cu4
-rw-r--r--ggml.c20
-rw-r--r--ggml.h2
-rw-r--r--tests/test-backend-ops.cpp41
4 files changed, 47 insertions, 20 deletions
diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu
index 4a558f4b..50f2cf41 100644
--- a/ggml-cuda/rope.cu
+++ b/ggml-cuda/rope.cu
@@ -283,9 +283,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
- if (is_neox) {
- pos = (const int32_t *) src1_d;
+ pos = (const int32_t *) src1_d;
+ if (is_neox) {
if (src2 != nullptr) {
freq_factors = (const float *) src2->data;
}
diff --git a/ggml.c b/ggml.c
index 37b16b7a..d316e3d3 100644
--- a/ggml.c
+++ b/ggml.c
@@ -6245,6 +6245,8 @@ static struct ggml_tensor * ggml_rope_impl(
float xpos_base,
bool xpos_down,
bool inplace) {
+ GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);
@@ -14413,7 +14415,7 @@ static void ggml_compute_forward_rope_f32(
freq_factors = (const float *) src2->data;
}
} else {
- GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1");
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
}
// backward process uses inverse rotation by cos and sin.
@@ -14529,6 +14531,7 @@ static void ggml_compute_forward_rope_f32(
}
}
+// TODO: deduplicate f16/f32 code
static void ggml_compute_forward_rope_f16(
const struct ggml_compute_params * params,
struct ggml_tensor * dst,
@@ -14536,6 +14539,7 @@ static void ggml_compute_forward_rope_f16(
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
+ const struct ggml_tensor * src2 = dst->src[2];
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return;
@@ -14588,6 +14592,17 @@ static void ggml_compute_forward_rope_f16(
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
+ const float * freq_factors = NULL;
+ if (is_neox) {
+ if (src2 != NULL) {
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+ freq_factors = (const float *) src2->data;
+ }
+ } else {
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+ }
+
// backward process uses inverse rotation by cos and sin.
// cos and sin build a rotation matrix, where the inverse is the transpose.
// this essentially just switches the sign of sin.
@@ -14660,10 +14675,11 @@ static void ggml_compute_forward_rope_f16(
// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;
+ float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+ theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
sin_theta *= sin_sign;
diff --git a/ggml.h b/ggml.h
index 35ac9110..08835042 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1460,7 +1460,7 @@ extern "C" {
struct ggml_tensor * b);
// rotary position embedding
- // if mode & 1 == 1, skip n_past elements (DEPRECATED)
+ // if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
// if mode & 2 == 1, GPT-NeoX style
// if mode & 4 == 1, ChatGLM style
//
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 1493a7ca..de74585d 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1142,20 +1142,22 @@ struct test_rope : public test_case {
int n_dims;
int mode;
int n_ctx;
+ bool ff;
std::string vars() override {
- return VARS_TO_STR5(type, ne, n_dims, mode, n_ctx);
+ return VARS_TO_STR6(type, ne, n_dims, mode, n_ctx, ff);
}
test_rope(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 1},
- int n_dims = 10, int mode = 0, int n_ctx = 512)
- : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx) {}
+ int n_dims = 10, int mode = 0, int n_ctx = 512, bool ff = false)
+ : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), ff(ff) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
- ggml_tensor * out = ggml_rope(ctx, a, pos, n_dims, mode, n_ctx);
+ ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
+ ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
return out;
}
@@ -1169,7 +1171,12 @@ struct test_rope : public test_case {
}
ggml_backend_tensor_set(t, data.data(), 0, ne[2] * sizeof(int));
} else {
- init_tensor_uniform(t);
+ if (t->ne[0] == n_dims/2) {
+ // frequency factors in the range [0.9f, 1.1f]
+ init_tensor_uniform(t, 0.9f, 1.1f);
+ } else {
+ init_tensor_uniform(t);
+ }
}
}
}
@@ -2188,16 +2195,20 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
- test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B
- test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512)); // llama 13B
- test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512)); // llama 30B
- test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512)); // llama 65B
- test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512)); // neox (falcon 7B)
- test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512)); // neox (falcon 7B)
- test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
- test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512)); // neox (falcon 40B)
- test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512)); // neox (stablelm)
- test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512)); // neox (phi-2)
+ // TODO: ff not supported yet for !neox
+ test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, false)); // llama 7B
+ test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, false)); // llama 13B
+ test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, false)); // llama 30B
+ test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, false)); // llama 65B
+
+ for (bool ff : {false, true}) { // freq_factors
+ test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512, ff)); // neox (falcon 7B)
+ test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512, ff)); // neox (falcon 7B)
+ test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512, ff)); // neox (falcon 40B)
+ test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512, ff)); // neox (falcon 40B)
+ test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512, ff)); // neox (stablelm)
+ test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512, ff)); // neox (phi-2)
+ }
}
test_cases.emplace_back(new test_concat(GGML_TYPE_F32));