diff options
Diffstat (limited to 'tests/test-backend-ops.cpp')
-rw-r--r-- | tests/test-backend-ops.cpp | 32 |
1 files changed, 29 insertions, 3 deletions
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f080f7e2..85ef21c2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1329,23 +1329,47 @@ struct test_upscale : public test_case { const ggml_type type; const std::array<int64_t, 4> ne; const int32_t scale_factor; + const bool transpose; std::string vars() override { - return VARS_TO_STR3(type, ne, scale_factor); + return VARS_TO_STR4(type, ne, scale_factor, transpose); } test_upscale(ggml_type type = GGML_TYPE_F32, std::array<int64_t, 4> ne = {512, 512, 3, 1}, - int32_t scale_factor = 2) - : type(type), ne(ne), scale_factor(scale_factor) {} + int32_t scale_factor = 2, bool transpose = false) + : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + if (transpose) a = ggml_transpose(ctx, a); ggml_tensor * out = ggml_upscale(ctx, a, scale_factor); return out; } }; +// GGML_OP_UPSCALE (ext) +struct test_upscale_ext : public test_case { + const ggml_type type; + const std::array<int64_t, 4> ne; + const std::array<int64_t, 4> ne_tgt; + + std::string vars() override { + return VARS_TO_STR3(type, ne, ne_tgt); + } + + test_upscale_ext(ggml_type type = GGML_TYPE_F32, + std::array<int64_t, 4> ne = {2, 5, 7, 11}, + std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13}) + : type(type), ne(ne), ne_tgt(ne_tgt) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3]); + return out; + } +}; + // GGML_OP_GROUP_NORM struct test_group_norm : public test_case { const ggml_type type; @@ -2169,6 +2193,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_sum_rows()); test_cases.emplace_back(new test_upscale()); + test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true)); + test_cases.emplace_back(new test_upscale_ext()); test_cases.emplace_back(new test_group_norm()); test_cases.emplace_back(new test_acc()); test_cases.emplace_back(new test_pad()); |