diff options
author | Johannes Gäßler <johannesg@5d6.de> | 2024-01-09 08:58:55 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-09 08:58:55 +0100 |
commit | 8f900abfc09851e281bc9027e0ab2f16bf079b29 (patch) | |
tree | 08f27d1c3a182663c28b7f36aee767d27cadbab6 /tests | |
parent | 1fc2f265ff9377a37fd2c61eae9cd813a3491bea (diff) |
CUDA: faster softmax via shared memory + fp16 math (#4742)
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test-backend-ops.cpp | 17 |
1 file changed, 15 insertions, 2 deletions
```diff
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index b79de7a7..7a60d774 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -450,7 +450,7 @@ struct test_case {
 
             double err = nmse(f1.data(), f2.data(), f1.size());
             if (err > ud->max_err) {
-                printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
+                printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
                 //for (int i = 0; i < (int) f1.size(); i++) {
                 //    printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
                 //}
@@ -1449,6 +1449,7 @@ struct test_moe : public test_case {
 
 static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     std::vector<std::unique_ptr<test_case>> test_cases;
+    std::default_random_engine rng(0);
 
     const ggml_type all_types[] = {
         GGML_TYPE_F32, GGML_TYPE_F16,
@@ -1583,7 +1584,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 1}, 5));
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));
 
-    test_cases.emplace_back(new test_soft_max());
+    std::uniform_int_distribution<> dist_ne1(1, 50);
+    int exponent = 1;
+    while (exponent < (1 << 17)) {
+        std::uniform_int_distribution<> dist_ne0(exponent, 2*exponent);
+
+        for (int n = 0; n < 10; ++n) {
+            int64_t ne0 = dist_ne0(rng);
+            int64_t ne1 = dist_ne1(rng);
+            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}));
+        }
+
+        exponent <<= 1;
+    }
 
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
         test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B
```