summaryrefslogtreecommitdiff
path: root/protocols/Telegram/tdlib/td/tdutils/test/StealingQueue.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'protocols/Telegram/tdlib/td/tdutils/test/StealingQueue.cpp')
-rw-r--r--protocols/Telegram/tdlib/td/tdutils/test/StealingQueue.cpp180
1 files changed, 180 insertions, 0 deletions
diff --git a/protocols/Telegram/tdlib/td/tdutils/test/StealingQueue.cpp b/protocols/Telegram/tdlib/td/tdutils/test/StealingQueue.cpp
new file mode 100644
index 0000000000..453a63179f
--- /dev/null
+++ b/protocols/Telegram/tdlib/td/tdutils/test/StealingQueue.cpp
@@ -0,0 +1,180 @@
+//
+// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022
+//
+// Distributed under the Boost Software License, Version 1.0. (See accompanying
+// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
+//
+#include "td/utils/AtomicRead.h"
+#include "td/utils/benchmark.h"
+#include "td/utils/common.h"
+#include "td/utils/logging.h"
+#include "td/utils/MpmcQueue.h"
+#include "td/utils/port/thread.h"
+#include "td/utils/Random.h"
+#include "td/utils/SliceBuilder.h"
+#include "td/utils/StealingQueue.h"
+#include "td/utils/tests.h"
+
+#include <atomic>
+#include <cstring>
+
+TEST(StealingQueue, very_simple) {
+ td::StealingQueue<int, 8> q;
+ q.local_push(1, [](auto x) { UNREACHABLE(); });
+ int x;
+ CHECK(q.local_pop(x));
+ ASSERT_EQ(1, x);
+}
+
+#if !TD_THREAD_UNSUPPORTED
+TEST(AtomicRead, simple) {
+ td::Stage run;
+ td::Stage check;
+
+ std::size_t threads_n = 10;
+ td::vector<td::thread> threads;
+
+ int x{0};
+ std::atomic<int> version{0};
+
+ td::int64 res = 0;
+ for (std::size_t i = 0; i < threads_n; i++) {
+ threads.emplace_back([&, id = static_cast<td::uint32>(i)] {
+ for (td::uint64 round = 1; round < 10000; round++) {
+ run.wait(round * threads_n);
+ if (id == 0) {
+ version++;
+ x++;
+ version++;
+ } else {
+ int y = 0;
+ auto v1 = version.load();
+ y = x;
+ auto v2 = version.load();
+ if (v1 == v2 && v1 % 2 == 0) {
+ res += y;
+ }
+ }
+
+ check.wait(round * threads_n);
+ }
+ });
+ }
+ td::do_not_optimize_away(res);
+ for (auto &thread : threads) {
+ thread.join();
+ }
+}
+
+TEST(AtomicRead, simple2) {
+ td::Stage run;
+ td::Stage check;
+
+ std::size_t threads_n = 10;
+ td::vector<td::thread> threads;
+
+ struct Value {
+ td::uint64 value = 0;
+ char str[50] = "0 0 0 0";
+ };
+ td::AtomicRead<Value> value;
+
+ auto to_str = [](td::uint64 i) {
+ return PSTRING() << i << " " << i << " " << i << " " << i;
+ };
+ for (std::size_t i = 0; i < threads_n; i++) {
+ threads.emplace_back([&, id = static_cast<td::uint32>(i)] {
+ for (td::uint64 round = 1; round < 10000; round++) {
+ run.wait(round * threads_n);
+ if (id == 0) {
+ auto x = value.lock();
+ x->value = round;
+ auto str = to_str(round);
+ std::memcpy(x->str, str.c_str(), str.size() + 1);
+ } else {
+ Value x;
+ value.read(x);
+ LOG_CHECK(x.value == round || x.value == round - 1) << x.value << " " << round;
+ CHECK(x.str == to_str(x.value));
+ }
+ check.wait(round * threads_n);
+ }
+ });
+ }
+ for (auto &thread : threads) {
+ thread.join();
+ }
+}
+
+TEST(StealingQueue, simple) {
+ td::uint64 sum = 0;
+ std::atomic<td::uint64> got_sum{0};
+
+ td::Stage run;
+ td::Stage check;
+
+ std::size_t threads_n = 10;
+ td::vector<td::thread> threads;
+ td::vector<td::StealingQueue<int, 8>> lq(threads_n);
+ td::MpmcQueue<int> gq(threads_n);
+
+ constexpr td::uint64 XN = 20;
+ td::uint64 x_sum[XN];
+ x_sum[0] = 0;
+ x_sum[1] = 1;
+ for (td::uint64 i = 2; i < XN; i++) {
+ x_sum[i] = i + x_sum[i - 1] + x_sum[i - 2];
+ }
+
+ td::Random::Xorshift128plus rnd(123);
+ for (std::size_t i = 0; i < threads_n; i++) {
+ threads.emplace_back([&, id = static_cast<td::uint32>(i)] {
+ for (td::uint64 round = 1; round < 1000; round++) {
+ if (id == 0) {
+ sum = 0;
+ auto n = static_cast<int>(rnd() % 5);
+ for (int j = 0; j < n; j++) {
+ auto x = static_cast<int>(rnd() % XN);
+ sum += x_sum[x];
+ gq.push(x, id);
+ }
+ got_sum = 0;
+ }
+ run.wait(round * threads_n);
+ while (got_sum.load() != sum) {
+ auto x = [&] {
+ int res;
+ if (lq[id].local_pop(res)) {
+ return res;
+ }
+ if (gq.try_pop(res, id)) {
+ return res;
+ }
+ if (lq[id].steal(res, lq[static_cast<size_t>(rnd()) % threads_n])) {
+ //LOG(ERROR) << "STEAL";
+ return res;
+ }
+ return 0;
+ }();
+ if (x == 0) {
+ continue;
+ }
+ //LOG(ERROR) << x << " " << got_sum.load() << " " << sum;
+ got_sum.fetch_add(x, std::memory_order_relaxed);
+ lq[id].local_push(x - 1, [&](auto y) {
+ //LOG(ERROR) << "OVERFLOW";
+ gq.push(y, id);
+ });
+ if (x > 1) {
+ lq[id].local_push(x - 2, [&](auto y) { gq.push(y, id); });
+ }
+ }
+ check.wait(round * threads_n);
+ }
+ });
+ }
+ for (auto &thread : threads) {
+ thread.join();
+ }
+}
+#endif