summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test-backend-ops.cpp29
1 files changed, 20 insertions, 9 deletions
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index ce406a8a..2b48e623 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -642,20 +642,29 @@ struct test_case {
struct test_unary : public test_case {
const ggml_unary_op op;
const ggml_type type;
- const std::array<int64_t, 4> ne;
+ const std::array<int64_t, 4> ne_a;
+ int v; // view (1 : non-contiguous a)
std::string vars() override {
- return VARS_TO_STR2(type, ne);
+ return VARS_TO_STR3(type, ne_a, v);
}
test_unary(ggml_unary_op op,
ggml_type type = GGML_TYPE_F32,
- std::array<int64_t, 4> ne = {128, 10, 10, 10})
- : op(op), type(type), ne(ne) {}
+ std::array<int64_t, 4> ne_a = {128, 10, 10, 10},
+ int v = 0)
+ : op(op), type(type), ne_a(ne_a), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
- ggml_tensor * in = ggml_new_tensor(ctx, type, 4, ne.data());
- ggml_tensor * out = ggml_unary(ctx, in, op);
+ ggml_tensor * a;
+ if (v & 1) {
+ auto ne = ne_a; ne[0] *= 3;
+ a = ggml_new_tensor(ctx, type, 4, ne.data());
+ a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+ } else {
+ a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+ }
+ ggml_tensor * out = ggml_unary(ctx, a, op);
return out;
}
@@ -2016,9 +2025,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
};
// unary ops
- for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
- test_cases.emplace_back(new test_unary((ggml_unary_op) op));
- test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 7, 13, 19, 23 }));
+ for (int v : {0, 1}) {
+ for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
+ test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 128, 10, 10, 10 }, v));
+ test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 7, 13, 19, 23 }, v));
+ }
}
test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));