summaryrefslogtreecommitdiff
path: root/tests/test-backend-ops.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test-backend-ops.cpp')
-rw-r--r--tests/test-backend-ops.cpp46
1 files changed, 46 insertions, 0 deletions
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index d4cea805..8a6999f2 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1412,6 +1412,50 @@ struct test_pad : public test_case {
}
};
+// GGML_OP_ARANGE
+struct test_arange : public test_case {
+ const ggml_type type;
+ const float start;
+ const float stop;
+ const float step;
+
+ std::string vars() override {
+ return VARS_TO_STR4(type, start, stop, step);
+ }
+
+ test_arange(ggml_type type = GGML_TYPE_F32,
+ float start = 0.f, float stop = 10.f, float step = 1.f)
+ : type(type), start(start), stop(stop), step(step) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * out = ggml_arange(ctx, start, stop, step);
+ return out;
+ }
+};
+
+// GGML_OP_TIMESTEP_EMBEDDING
+struct test_timestep_embedding : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne_a;
+ const int dim;
+ const int max_period;
+
+ std::string vars() override {
+ return VARS_TO_STR4(type, ne_a, dim, max_period);
+ }
+
+ test_timestep_embedding(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne_a = {2, 1, 1, 1},
+ int dim = 320, int max_period=10000)
+ : type(type), ne_a(ne_a), dim(dim), max_period(max_period) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+ ggml_tensor * out = ggml_timestep_embedding(ctx, a, dim, max_period);
+ return out;
+ }
+};
+
// GGML_OP_LEAKY_RELU
struct test_leaky_relu : public test_case {
const ggml_type type;
@@ -2126,6 +2170,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_group_norm());
test_cases.emplace_back(new test_acc());
test_cases.emplace_back(new test_pad());
+ test_cases.emplace_back(new test_arange());
+ test_cases.emplace_back(new test_timestep_embedding());
test_cases.emplace_back(new test_leaky_relu());
// these tests are disabled to save execution time, but they can be handy for debugging