path: root/tests/test-backend-ops.cpp
author    Jared Van Bortel <jared@nomic.ai>  2024-01-29 15:50:50 -0500
committer GitHub <noreply@github.com>  2024-01-29 15:50:50 -0500
commit    fbf1ddec69f7001cc707de17fa74d7200813bbac (patch)
tree      55ff0324a0fe0dfc3de70d232a29b04926657ae1 /tests/test-backend-ops.cpp
parent    2aed77eb06a329f0d82bb1c467f4244904d4073f (diff)
Nomic Vulkan backend (#4456)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: niansa <anton-sa@web.de>
Co-authored-by: Adam Treat <treat.adam@gmail.com>
Co-authored-by: Aaron Miller <apage43@ninjawhale.com>
Co-authored-by: ToKiNoBug <tokinobug@163.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
Diffstat (limited to 'tests/test-backend-ops.cpp')
-rw-r--r--  tests/test-backend-ops.cpp | 430
1 file changed, 422 insertions(+), 8 deletions(-)
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 01593910..775147d4 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -370,12 +370,15 @@ struct test_case {
printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
fflush(stdout);
- // check if backends support op
+ // check if the backends support the ops
bool supported = true;
for (ggml_backend_t backend : {backend1, backend2}) {
- if (!ggml_backend_supports_op(backend, out)) {
- printf("not supported [%s] ", ggml_backend_name(backend));
- supported = false;
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (!ggml_backend_supports_op(backend, t)) {
+ printf("not supported [%s] ", ggml_backend_name(backend));
+ supported = false;
+ break;
+ }
}
}
if (!supported) {
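
For reference, the updated check walks every tensor created in the test's ggml_context rather than only the final output node, so a backend missing support for any intermediate op is skipped. A minimal standalone sketch of that pattern using ggml's public backend API (the helper name is illustrative, not from this file):

    #include <cstdio>
    #include "ggml.h"
    #include "ggml-backend.h"

    // returns false if the backend cannot execute some tensor (op) in the context
    static bool backend_supports_all_ops(ggml_backend_t backend, ggml_context * ctx) {
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            if (!ggml_backend_supports_op(backend, t)) {
                printf("not supported [%s] %s\n", ggml_backend_name(backend), ggml_op_name(t->op));
                return false;
            }
        }
        return true;
    }
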
@@ -626,6 +629,13 @@ struct test_unary : public test_case {
ggml_tensor * out = ggml_unary(ctx, in, op);
return out;
}
+
+ void initialize_tensors(ggml_context * ctx) override {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ // test extended range of values to check for NaNs in GELU
+ init_tensor_uniform(t, -150.f, 150.f);
+ }
+ }
};
// GGML_OP_GET_ROWS
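
The extended ±150 range used by test_unary above is much wider than the default init_tensor_uniform range; per the comment, the point is to surface NaNs that a GELU implementation can produce on large inputs, for example when a tanh-based approximation evaluates exponentials that overflow. A self-contained illustration of that failure mode (not code from this repository; naive_tanh is a deliberately fragile stand-in):

    #include <cmath>
    #include <cstdio>

    // naive tanh via exponentials: for large |z|, exp(z) overflows to inf and inf/inf gives NaN
    static float naive_tanh(float z) {
        return (std::exp(z) - std::exp(-z)) / (std::exp(z) + std::exp(-z));
    }

    static float gelu_tanh_approx(float x) {
        const float c = 0.797884561f; // sqrt(2/pi)
        return 0.5f * x * (1.0f + naive_tanh(c * (x + 0.044715f * x * x * x)));
    }

    int main() {
        printf("gelu(150) = %f\n", gelu_tanh_approx(150.0f)); // prints nan with the naive tanh
    }
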
@@ -1066,18 +1076,24 @@ struct test_diag_mask_inf : public test_case {
struct test_soft_max : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
+ const float scale;
+ const bool mask;
std::string vars() override {
- return VARS_TO_STR2(type, ne);
+ return VARS_TO_STR4(type, ne, scale, mask);
}
test_soft_max(ggml_type type = GGML_TYPE_F32,
- std::array<int64_t, 4> ne = {10, 10, 10, 10})
- : type(type), ne(ne) {}
+ std::array<int64_t, 4> ne = {10, 10, 10, 10},
+ float scale = 1.0f,
+ bool mask = false)
+ : type(type), ne(ne), scale(scale), mask(mask) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
- ggml_tensor * out = ggml_soft_max(ctx, a);
+ ggml_tensor * b = nullptr;
+ if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); }
+ ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale);
return out;
}
};
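
The new scale and mask arguments are forwarded to ggml_soft_max_ext, which computes softmax(a*scale + mask) with the 2D mask broadcast over the higher dimensions of a. A minimal sketch mirroring the masked, scaled case added to the test list further below (sizes arbitrary; the (a, mask, scale) signature is the one used at this commit, later ggml versions extend it):

    #include "ggml.h"

    // graph fragment: softmax(a * 0.1f + mask)
    static ggml_tensor * scaled_masked_softmax(ggml_context * ctx) {
        ggml_tensor * a    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 32, 2, 32, 1);
        ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 32, 2);
        return ggml_soft_max_ext(ctx, a, mask, 0.1f);
    }
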
@@ -1474,6 +1490,393 @@ struct test_moe : public test_case {
}
};
+
+enum llm_norm_type {
+ LLM_NORM,
+ LLM_NORM_RMS,
+};
+
+struct llama_hparams {
+ uint32_t n_vocab;
+ uint32_t n_embd;
+ uint32_t n_head;
+ uint32_t n_head_kv;
+ static constexpr uint32_t n_layer = 1;
+ uint32_t n_rot;
+ uint32_t n_embd_head; // dimension of values (d_v)
+ uint32_t n_ff;
+
+ float f_norm_eps;
+ float f_norm_rms_eps;
+
+ // cparams
+ static constexpr uint32_t n_ctx = 512; // user-specified context size
+ static constexpr uint32_t n_orig_ctx = n_ctx;
+
+ // batch
+ int32_t n_tokens;
+
+ // llm_build_context
+ static constexpr int32_t n_kv = 32; // size of KV cache to consider (n_kv <= n_ctx)
+ static constexpr int32_t kv_head = 1; // index of where we store new KV data in the cache
+
+ uint32_t n_embd_gqa() const { // dimension of key embeddings across all k-v heads
+ return n_embd_head * n_head_kv;
+ }
+};
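// Quick sanity check on these hyperparameters (annotation, not part of the diff):
// n_embd_gqa() = n_embd_head * n_head_kv, so the Llama test below gets 100 * 32 = 3200
// (equal to n_embd, i.e. no grouped-query sharing) and the Falcon test gets 64 * 1 = 64
// (multi-query attention, a single KV head). The fixed K/V buffers of 1638400 F16 elements
// allocated in both build_graph functions equal n_ctx * n_embd_gqa = 512 * 3200 for the
// Llama case; the Falcon graph reuses the same size, larger than its 512 * 64 requirement.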
+
+// LLM base class
+struct test_llm : public test_case {
+ llama_hparams hp;
+
+protected:
+ test_llm(llama_hparams hp)
+ : hp(std::move(hp)) {
+ }
+
+public:
+ struct ggml_tensor * llm_build_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * cur,
+ struct ggml_tensor * mw,
+ struct ggml_tensor * mb,
+ llm_norm_type type) {
+ switch (type) {
+ case LLM_NORM: cur = ggml_norm (ctx, cur, hp.f_norm_eps); break;
+ case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hp.f_norm_rms_eps); break;
+ }
+ cur = ggml_mul(ctx, cur, mw);
+ if (mb) {
+ cur = ggml_add(ctx, cur, mb);
+ }
+ return cur;
+ }
+
+ void llm_build_kv_store(
+ struct ggml_context * ctx,
+ struct ggml_tensor * k_l,
+ struct ggml_tensor * v_l,
+ struct ggml_tensor * k_cur,
+ struct ggml_tensor * v_cur) {
+ // compute the transposed [n_tokens, n_embd] V matrix
+ struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, hp.n_embd_gqa(), hp.n_tokens));
+
+ struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, k_l, hp.n_tokens*hp.n_embd_gqa(),
+ (ggml_row_size(k_l->type, hp.n_embd_gqa()))*hp.kv_head);
+
+ struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, v_l, hp.n_tokens, hp.n_embd_gqa(),
+ ( hp.n_ctx)*ggml_element_size(v_l),
+ (hp.kv_head)*ggml_element_size(v_l));
+
+ // important: storing RoPE-ed version of K in the KV cache!
+ ggml_cpy(ctx, k_cur, k_cache_view);
+ ggml_cpy(ctx, v_cur_t, v_cache_view);
+ }
+
+ // if max_alibi_bias > 0 then apply ALiBi
+ struct ggml_tensor * llm_build_kqv(
+ struct ggml_context * ctx,
+ struct ggml_tensor * k_l,
+ struct ggml_tensor * v_l,
+ struct ggml_tensor * q_cur,
+ struct ggml_tensor * kq_mask,
+ float kq_scale) {
+ struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
+
+ struct ggml_tensor * k =
+ ggml_view_3d(ctx, k_l,
+ hp.n_embd_head, hp.n_kv, hp.n_head_kv,
+ ggml_row_size(k_l->type, hp.n_embd_gqa()),
+ ggml_row_size(k_l->type, hp.n_embd_head),
+ 0);
+
+ struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
+
+ kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
+
+ // split cached v into n_head heads
+ struct ggml_tensor * v =
+ ggml_view_3d(ctx, v_l,
+ hp.n_kv, hp.n_embd_head, hp.n_head_kv,
+ ggml_element_size(v_l)*hp.n_ctx,
+ ggml_element_size(v_l)*hp.n_ctx*hp.n_embd_head,
+ 0);
+
+ struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
+
+ struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
+
+ struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, hp.n_embd_head*hp.n_head, hp.n_tokens);
+
+ struct ggml_tensor * wo = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
+ cur = ggml_mul_mat(ctx, wo, cur);
+
+ return cur;
+ }
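// Shape walk-through for llm_build_kqv above (ggml's ne[0] is the innermost dimension):
//   q   = permute(q_cur, 0,2,1,3)          -> [n_embd_head, n_tokens, n_head]
//   k   = view_3d(k_l, ...)                -> [n_embd_head, n_kv, n_head_kv]
//   kq  = mul_mat(k, q)                    -> [n_kv, n_tokens, n_head]  (KV heads broadcast when n_head_kv < n_head)
//   v   = view_3d(v_l, ...)                -> [n_kv, n_embd_head, n_head_kv]
//   kqv = mul_mat(v, kq)                   -> [n_embd_head, n_tokens, n_head]
//   out = cont_2d(permute(kqv, 0,2,1,3))   -> [n_embd_head*n_head, n_tokens]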
+
+ void initialize_tensors(ggml_context * ctx) override {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (t->type == GGML_TYPE_I32) {
+ // pos
+ std::vector<int> data(hp.n_tokens);
+ for (int i = 0; i < hp.n_tokens; i++) {
+ data[i] = rand() % hp.n_ctx;
+ }
+ ggml_backend_tensor_set(t, data.data(), 0, hp.n_tokens * sizeof(int));
+ } else {
+ init_tensor_uniform(t);
+ }
+ }
+ }
+};
+
+
+// Llama
+struct test_llama : public test_llm {
+ static constexpr float freq_base = 10000.0f;
+ static constexpr float freq_scale = 1.0f;
+ static constexpr float ext_factor = 0.0f;
+ static constexpr float attn_factor = 1.0f;
+ static constexpr float beta_fast = 32.0f;
+ static constexpr float beta_slow = 1.0f;
+
+ std::string op_desc(ggml_tensor * t) override {
+ GGML_UNUSED(t);
+ return "LLAMA";
+ }
+
+ std::string vars() override {
+ auto n_tokens = hp.n_tokens;
+ return VARS_TO_STR1(n_tokens);
+ }
+
+ double max_nmse_err() override {
+ return 2e-3;
+ }
+
+ test_llama(int n_tokens = 1)
+ : test_llm({
+ /*n_vocab =*/ 32000,
+ /*n_embd =*/ 3200,
+ /*n_head =*/ 32,
+ /*n_head_kv =*/ 32,
+ /*n_rot =*/ 100,
+ /*n_embd_head =*/ 100,
+ /*n_ff =*/ 8640,
+ /*f_norm_eps =*/ 0.f,
+ /*f_norm_rms_eps =*/ 1e-5f,
+ /*n_tokens =*/ n_tokens,
+ }) {
+ }
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
+
+ ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
+ ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
+
+ for (uint32_t il = 0; il < hp.n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+
+ // norm
+ ggml_tensor * attn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+ cur = llm_build_norm(ctx, inpL, attn_norm, nullptr, LLM_NORM_RMS);
+
+ // self-attention
+ {
+ ggml_tensor * wq = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
+ ggml_tensor * wk = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());
+ ggml_tensor * wv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());
+
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx, wq, cur);
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur);
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur);
+
+ Qcur = ggml_rope_custom(
+ ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos,
+ hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ Kcur = ggml_rope_custom(
+ ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos,
+ hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);
+
+ cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
+ }
+
+ struct ggml_tensor * ffn_inp = ggml_add(ctx, cur, inpSA);
+
+ // feed-forward network
+ ggml_tensor * ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+ cur = llm_build_norm(ctx, ffn_inp, ffn_norm, nullptr, LLM_NORM_RMS);
+
+ ggml_tensor * ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
+ ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff, hp.n_embd);
+ ggml_tensor * ffn_up = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
+ struct ggml_tensor * tmp = ggml_mul_mat(ctx, ffn_up, cur);
+ cur = ggml_mul_mat(ctx, ffn_gate, cur);
+ cur = ggml_silu(ctx, cur);
+ cur = ggml_mul(ctx, cur, tmp);
+ cur = ggml_mul_mat(ctx, ffn_down, cur);
+
+ cur = ggml_add(ctx, cur, ffn_inp);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+ cur = llm_build_norm(ctx, cur, output_norm, nullptr, LLM_NORM_RMS);
+
+ // lm_head
+ ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_vocab);
+ cur = ggml_mul_mat(ctx, output, cur);
+
+ return cur;
+ }
+};
+
+// Falcon
+struct test_falcon : public test_llm {
+ static constexpr float freq_base = 10000.0f;
+ static constexpr float freq_scale = 1.0f;
+ static constexpr float ext_factor = 0.0f;
+ static constexpr float attn_factor = 1.0f;
+ static constexpr float beta_fast = 32.0f;
+ static constexpr float beta_slow = 1.0f;
+
+ std::string op_desc(ggml_tensor * t) override {
+ GGML_UNUSED(t);
+ return "FALCON";
+ }
+
+ std::string vars() override {
+ auto n_tokens = hp.n_tokens;
+ return VARS_TO_STR1(n_tokens);
+ }
+
+ double max_nmse_err() override {
+ return 2e-3;
+ }
+
+ test_falcon(int n_tokens = 1)
+ : test_llm({
+ /*n_vocab =*/ 32000,
+ /*n_embd =*/ 3200,
+ /*n_head =*/ 50,
+ /*n_head_kv =*/ 1,
+ /*n_rot =*/ 64,
+ /*n_embd_head =*/ 64,
+ /*n_ff =*/ 8640,
+ /*f_norm_eps =*/ 1e-5f,
+ /*f_norm_rms_eps =*/ 0.f,
+ /*n_tokens =*/ n_tokens,
+ }) {
+ }
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
+
+ ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
+ ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
+
+ for (uint32_t il = 0; il < hp.n_layer; ++il) {
+ // norm
+ ggml_tensor * attn_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+ ggml_tensor * attn_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+ ggml_tensor * attn_norm = llm_build_norm(ctx, inpL, attn_norm_w, attn_norm_b, LLM_NORM);
+
+ // self-attention
+ {
+ cur = attn_norm;
+
+ ggml_tensor * wqkv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd + 2*hp.n_embd_gqa());
+
+ cur = ggml_mul_mat(ctx, wqkv, cur);
+
+ struct ggml_tensor * Qcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd, hp.n_tokens, cur->nb[1], 0*sizeof(float)*(hp.n_embd)));
+ struct ggml_tensor * Kcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd)));
+ struct ggml_tensor * Vcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd + hp.n_embd_gqa())));
+
+ Qcur = ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens);
+ Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens);
+
+ // using mode = 2 for neox mode
+ Qcur = ggml_rope_custom(
+ ctx, Qcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ Kcur = ggml_rope_custom(
+ ctx, Kcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);
+
+ cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
+ }
+
+ struct ggml_tensor * ffn_inp = cur;
+
+ // feed forward
+ {
+ ggml_tensor * ffn_up = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
+ ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff, hp.n_embd);
+ cur = attn_norm;
+ cur = ggml_mul_mat(ctx, ffn_up, cur);
+ cur = ggml_gelu(ctx, cur);
+ cur = ggml_mul_mat(ctx, ffn_down, cur);
+ }
+
+ cur = ggml_add(ctx, cur, ffn_inp);
+
+ cur = ggml_add(ctx, cur, inpL);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+ ggml_tensor * output_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+ cur = llm_build_norm(ctx, cur, output_norm, output_norm_b, LLM_NORM);
+
+ // lm_head
+ ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_0, hp.n_embd, hp.n_vocab);
+ cur = ggml_mul_mat(ctx, output, cur);
+
+ return cur;
+ }
+};
+
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
std::vector<std::unique_ptr<test_case>> test_cases;
std::default_random_engine rng(0);
@@ -1626,6 +2029,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
exponent <<= 1;
}
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, 0.1f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, 0.1f, true));
+
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B
test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512)); // llama 13B
@@ -1662,6 +2068,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
//test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336));
#endif
+ // these tests are disabled to save execution time, but they can be handy for debugging
+#if 0
+ test_cases.emplace_back(new test_llama(1));
+ test_cases.emplace_back(new test_llama(2));
+ test_cases.emplace_back(new test_falcon(1));
+ test_cases.emplace_back(new test_falcon(2));
+#endif
+
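// Usage note: if these LLM cases are enabled (e.g. by flipping the #if 0 to #if 1), they can be
// run in isolation through the driver's op filter, since their op_desc() reports "LLAMA" and
// "FALCON"; assuming the -o flag this file's main() already parses, that would be something like
//   ./test-backend-ops test -o LLAMA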
// run tests
if (mode == MODE_TEST) {
ggml_backend_t backend_cpu = ggml_backend_cpu_init();