diff options
Diffstat (limited to 'tests/test-backend-ops.cpp')
| -rw-r--r-- | tests/test-backend-ops.cpp | 31 |
1 files changed, 16 insertions, 15 deletions
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 8dc90a45..ce406a8a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1141,7 +1141,7 @@ struct test_rope : public test_case { const std::array<int64_t, 4> ne_a; int n_dims; int mode; - int n_ctx; + int n_ctx; // used to generate positions float fs; // freq_scale float ef; // ext_factor float af; // attn_factor @@ -1168,7 +1168,7 @@ struct test_rope : public test_case { } ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]); ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr; - ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); return out; } @@ -1615,7 +1615,7 @@ struct llama_hparams { // cparams static constexpr uint32_t n_ctx = 512; // user-specified context size - static constexpr uint32_t n_orig_ctx = n_ctx; + static constexpr uint32_t n_ctx_orig = n_ctx; // batch int32_t n_tokens; @@ -1806,13 +1806,13 @@ struct test_llama : public test_llm { Qcur = ggml_rope_ext( ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr, - hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale, + hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); Kcur = ggml_rope_ext( ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr, - hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale, + hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -1931,12 +1931,12 @@ struct test_falcon : public test_llm { // using mode = 2 for neox mode Qcur = ggml_rope_ext( - ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx, + ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); Kcur = ggml_rope_ext( - ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx, + ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -2236,15 +2236,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (float ef : { 0.0f, 0.7465f }) { for (float af : { 1.0f, 1.4245f }) { for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { - // TODO: ff not supported yet for !neox - test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 7B - if (all) { - test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 13B - test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 30B - test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 65B - } - for (bool ff : {false, true}) { // freq_factors + test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B + + if (all) { + test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B + test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B + test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B + } + if (all) { test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B) @@ -2256,6 +2256,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B) } } + all = false; } } |
