Diffstat (limited to 'tests/test-backend-ops.cpp')
-rw-r--r--   tests/test-backend-ops.cpp   31
1 file changed, 16 insertions(+), 15 deletions(-)
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 8dc90a45..ce406a8a 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1141,7 +1141,7 @@ struct test_rope : public test_case {
const std::array<int64_t, 4> ne_a;
int n_dims;
int mode;
- int n_ctx;
+ int n_ctx; // used to generate positions
float fs; // freq_scale
float ef; // ext_factor
float af; // attn_factor
@@ -1168,7 +1168,7 @@ struct test_rope : public test_case {
}
ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
- ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+ ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
return out;
}
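
For reference, the call above maps onto the post-change ggml_rope_ext parameter order as follows: the separate n_ctx argument is gone, and the original-context value (0 in this test) now follows mode directly. The wrapper below is an illustrative sketch only, with build_rope and the annotated roles inferred from the values passed in this hunk; it is not part of the test file.

#include "ggml.h"

ggml_tensor * build_rope(ggml_context * ctx, ggml_tensor * a,
                         ggml_tensor * pos, ggml_tensor * freq_factors,
                         int n_dims, int mode,
                         float fs, float ef, float af) {
    // freq_factors may be nullptr when frequency factors are not used
    return ggml_rope_ext(
        ctx, a, pos, freq_factors,
        n_dims, mode,
        /*n_ctx_orig  =*/ 0,
        /*freq_base   =*/ 10000.0f,
        /*freq_scale  =*/ fs,
        /*ext_factor  =*/ ef,
        /*attn_factor =*/ af,
        /*beta_fast   =*/ 1.0f,
        /*beta_slow   =*/ 1.0f);
}
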
@@ -1615,7 +1615,7 @@ struct llama_hparams {
// cparams
static constexpr uint32_t n_ctx = 512; // user-specified context size
- static constexpr uint32_t n_orig_ctx = n_ctx;
+ static constexpr uint32_t n_ctx_orig = n_ctx;
// batch
int32_t n_tokens;
@@ -1806,13 +1806,13 @@ struct test_llama : public test_llm {
Qcur = ggml_rope_ext(
ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr,
- hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
+ hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
- hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
+ hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@@ -1931,12 +1931,12 @@ struct test_falcon : public test_llm {
// using mode = 2 for neox mode
Qcur = ggml_rope_ext(
- ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
+ ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
- ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
+ ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
@@ -2236,15 +2236,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (float ef : { 0.0f, 0.7465f }) {
for (float af : { 1.0f, 1.4245f }) {
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
- // TODO: ff not supported yet for !neox
- test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 7B
- if (all) {
- test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 13B
- test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 30B
- test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 65B
- }
-
for (bool ff : {false, true}) { // freq_factors
+ test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B
+
+ if (all) {
+ test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B
+ test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B
+ test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B
+ }
+
if (all) {
test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
@@ -2256,6 +2256,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
}
}
+
all = false;
}
}
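
The restructured loop above moves the llama-shaped cases inside the freq_factors (ff) loop, so frequency factors are now exercised for regular-mode (mode 0) RoPE as well as for neox (mode 2). A standalone sketch of that enumeration follows; rope_cfg and enumerate_rope_cases are illustrative names that do not appear in the test file, and only one shape per mode is shown.

#include <cstdint>
#include <vector>

struct rope_cfg {
    int64_t ne[4];   // tensor shape, as passed to test_rope
    int     n_dims;  // number of rotated dimensions
    int     mode;    // 0 = normal (llama-style), 2 = neox (falcon-style)
    bool    ff;      // whether freq_factors are used
};

std::vector<rope_cfg> enumerate_rope_cases() {
    std::vector<rope_cfg> cases;
    // ff is now the outer knob: every shape is tested with and without
    // frequency factors, regardless of mode.
    for (bool ff : { false, true }) {
        cases.push_back({ {128, 32, 10, 1}, 128, 0, ff }); // llama 7B style
        cases.push_back({ { 64,  1, 10, 1},  64, 2, ff }); // neox (falcon) style
    }
    return cases;
}
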