diff options
author | Calvin Laurenson <calvin@laurenson.dev> | 2024-06-16 15:23:04 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-17 00:23:04 +0200 |
commit | 43b35e38ba371f9a7faa6dca4c5d1e8f698ffd87 (patch) | |
tree | 11f250899027f3249c9ee15ffaff2048c9b81268 /tests/test-backend-ops.cpp | |
parent | 19b7a836f6658e18e973af532a5cc6ad6b3a27f8 (diff) |
Add support for sqrt on CUDA (#7953)
* cuda sqrt support
* enable cuda in pca
* fix comments in pca
* add test
* add sqrt to ggml_backend_cuda_supports_op
* fix test
* new line
* Use F32 sqrtf instead of F64 sqrt
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
---------
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Diffstat (limited to 'tests/test-backend-ops.cpp')
-rw-r--r-- | tests/test-backend-ops.cpp | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2b48e623..7c504e93 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1063,6 +1063,33 @@ struct test_sqr : public test_case { } }; +// GGML_OP_SQRT +struct test_sqrt : public test_case { + const ggml_type type; + const std::array<int64_t, 4> ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_sqrt(ggml_type type = GGML_TYPE_F32, + std::array<int64_t, 4> ne = {10, 10, 10, 10}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * out = ggml_sqrt(ctx, a); + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + // fill with positive values + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t, 0.0f, 100.0f); + } + } +}; + // GGML_OP_CLAMP struct test_clamp : public test_case { const ggml_type type; @@ -2200,6 +2227,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } test_cases.emplace_back(new test_sqr()); + test_cases.emplace_back(new test_sqrt()); test_cases.emplace_back(new test_clamp()); test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5)); |