summaryrefslogtreecommitdiff
path: root/tests/test-backend-ops.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test-backend-ops.cpp')
-rw-r--r--tests/test-backend-ops.cpp32
1 files changed, 29 insertions, 3 deletions
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index f080f7e2..85ef21c2 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1329,23 +1329,47 @@ struct test_upscale : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const int32_t scale_factor;
+ const bool transpose;
std::string vars() override {
- return VARS_TO_STR3(type, ne, scale_factor);
+ return VARS_TO_STR4(type, ne, scale_factor, transpose);
}
test_upscale(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {512, 512, 3, 1},
- int32_t scale_factor = 2)
- : type(type), ne(ne), scale_factor(scale_factor) {}
+ int32_t scale_factor = 2, bool transpose = false)
+ : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+ if (transpose) a = ggml_transpose(ctx, a);
ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
return out;
}
};
+// GGML_OP_UPSCALE (ext)
+struct test_upscale_ext : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne;
+ const std::array<int64_t, 4> ne_tgt;
+
+ std::string vars() override {
+ return VARS_TO_STR3(type, ne, ne_tgt);
+ }
+
+ test_upscale_ext(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne = {2, 5, 7, 11},
+ std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13})
+ : type(type), ne(ne), ne_tgt(ne_tgt) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3]);
+ return out;
+ }
+};
+
// GGML_OP_GROUP_NORM
struct test_group_norm : public test_case {
const ggml_type type;
@@ -2169,6 +2193,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_sum_rows());
test_cases.emplace_back(new test_upscale());
+ test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
+ test_cases.emplace_back(new test_upscale_ext());
test_cases.emplace_back(new test_group_norm());
test_cases.emplace_back(new test_acc());
test_cases.emplace_back(new test_pad());