diff options
Diffstat (limited to 'tests/test-backend-ops.cpp')
-rw-r--r-- | tests/test-backend-ops.cpp | 28 |
1 files changed, 17 insertions, 11 deletions
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index de74585d..b200cccc 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1259,22 +1259,26 @@ struct test_im2col : public test_case { // GGML_OP_CONCAT struct test_concat : public test_case { const ggml_type type; - const std::array<int64_t, 4> ne; - const int64_t b_ne2; + const std::array<int64_t, 4> ne_a; + const int64_t ne_b_d; + const int dim; std::string vars() override { - return VARS_TO_STR3(type, ne, b_ne2); + return VARS_TO_STR4(type, ne_a, ne_b_d, dim); } test_concat(ggml_type type = GGML_TYPE_F32, - std::array<int64_t, 4> ne = {10, 10, 10, 10}, - int64_t b_ne2 = 10) - : type(type), ne(ne), b_ne2(b_ne2) {} + std::array<int64_t, 4> ne_a = {10, 10, 10, 10}, + int64_t ne_b_d = 10, + int dim = 2) + : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], b_ne2, ne[3]); - ggml_tensor * out = ggml_concat(ctx, a, b); + auto ne_b = ne_a; + ne_b[dim] = ne_b_d; + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data()); + ggml_tensor * out = ggml_concat(ctx, a, b, dim); return out; } }; @@ -2211,8 +2215,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } - test_cases.emplace_back(new test_concat(GGML_TYPE_F32)); - test_cases.emplace_back(new test_concat(GGML_TYPE_I32)); + for (int dim : { 0, 1, 2, 3, }) { + test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim)); + test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim)); + } for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order)); |