summaryrefslogtreecommitdiff
path: root/tests/test-backend-ops.cpp
diff options
context:
space:
mode:
authorCalvin Laurenson <calvin@laurenson.dev>2024-06-16 15:23:04 -0700
committerGitHub <noreply@github.com>2024-06-17 00:23:04 +0200
commit43b35e38ba371f9a7faa6dca4c5d1e8f698ffd87 (patch)
tree11f250899027f3249c9ee15ffaff2048c9b81268 /tests/test-backend-ops.cpp
parent19b7a836f6658e18e973af532a5cc6ad6b3a27f8 (diff)
Add support for sqrt on CUDA (#7953)
* cuda sqrt support * enable cuda in pca * fix comments in pca * add test * add sqrt to ggml_backend_cuda_supports_op * fix test * new line * Use F32 sqrtf instead of F64 sqrt Co-authored-by: Johannes Gäßler <johannesg@5d6.de> --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Diffstat (limited to 'tests/test-backend-ops.cpp')
-rw-r--r--tests/test-backend-ops.cpp28
1 files changed, 28 insertions, 0 deletions
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 2b48e623..7c504e93 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1063,6 +1063,33 @@ struct test_sqr : public test_case {
}
};
+// GGML_OP_SQRT
+struct test_sqrt : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne;
+
+ std::string vars() override {
+ return VARS_TO_STR2(type, ne);
+ }
+
+ test_sqrt(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne = {10, 10, 10, 10})
+ : type(type), ne(ne) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_tensor * out = ggml_sqrt(ctx, a);
+ return out;
+ }
+
+ void initialize_tensors(ggml_context * ctx) override {
+ // fill with positive values
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ init_tensor_uniform(t, 0.0f, 100.0f);
+ }
+ }
+};
+
// GGML_OP_CLAMP
struct test_clamp : public test_case {
const ggml_type type;
@@ -2200,6 +2227,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
test_cases.emplace_back(new test_sqr());
+ test_cases.emplace_back(new test_sqrt());
test_cases.emplace_back(new test_clamp());
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));