diff options
Diffstat (limited to 'tests/test-backend-ops.cpp')
-rw-r--r-- | tests/test-backend-ops.cpp | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d4cea805..8a6999f2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1412,6 +1412,50 @@ struct test_pad : public test_case { } }; +// GGML_OP_ARANGE +struct test_arange : public test_case { + const ggml_type type; + const float start; + const float stop; + const float step; + + std::string vars() override { + return VARS_TO_STR4(type, start, stop, step); + } + + test_arange(ggml_type type = GGML_TYPE_F32, + float start = 0.f, float stop = 10.f, float step = 1.f) + : type(type), start(start), stop(stop), step(step) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * out = ggml_arange(ctx, start, stop, step); + return out; + } +}; + +// GGML_OP_TIMESTEP_EMBEDDING +struct test_timestep_embedding : public test_case { + const ggml_type type; + const std::array<int64_t, 4> ne_a; + const int dim; + const int max_period; + + std::string vars() override { + return VARS_TO_STR4(type, ne_a, dim, max_period); + } + + test_timestep_embedding(ggml_type type = GGML_TYPE_F32, + std::array<int64_t, 4> ne_a = {2, 1, 1, 1}, + int dim = 320, int max_period=10000) + : type(type), ne_a(ne_a), dim(dim), max_period(max_period) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_tensor * out = ggml_timestep_embedding(ctx, a, dim, max_period); + return out; + } +}; + // GGML_OP_LEAKY_RELU struct test_leaky_relu : public test_case { const ggml_type type; @@ -2126,6 +2170,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_group_norm()); test_cases.emplace_back(new test_acc()); test_cases.emplace_back(new test_pad()); + test_cases.emplace_back(new test_arange()); + test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); // these tests are disabled to save execution time, but they can be handy for debugging |