Skip to content

Commit 72e03ea

Browse files
committed
Address review comments, add tests, benchmarks, and cleanups
Signed-off-by: Abhishek Bansal <abhibansal593@gmail.com>
1 parent 83173e0 commit 72e03ea

6 files changed

Lines changed: 280 additions & 27 deletions

File tree

libs/full/collectives/include/hpx/collectives/barrier.hpp

Lines changed: 8 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -102,9 +102,7 @@ namespace hpx { namespace distributed {
102102
#include <hpx/modules/components_base.hpp>
103103
#include <hpx/modules/errors.hpp>
104104
#include <hpx/modules/futures.hpp>
105-
#include <hpx/modules/memory.hpp>
106105

107-
#include <array>
108106
#include <atomic>
109107
#include <cstddef>
110108
#include <cstdint>
@@ -263,14 +261,19 @@ namespace hpx::collectives {
263261
HPX_GET_EXCEPTION(hpx::error::bad_parameter,
264262
"hpx::collectives::barrier (hierarchical)",
265263
"hierarchical barrier requires an explicit generation "
266-
"number for the 2k/2k+1 internal mapping"));
264+
"number for the 2k-1/2k internal mapping"));
267265
}
268266

269267
if (this_site.is_default())
270268
{
271269
this_site = agas::get_locality_id();
272270
}
273271

272+
if (communicators.size() == 0)
273+
{
274+
return hpx::make_ready_future();
275+
}
276+
274277
generation_arg const reduce_gen(2 * generation - 1);
275278
generation_arg const broadcast_gen(2 * generation);
276279

@@ -287,10 +290,6 @@ namespace hpx::collectives {
287290

288291
// Broadcast phase: walk sub-communicators from shallowest to deepest.
289292
// Returning the final future lets the caller chain on completion.
290-
if (communicators.size() == 0)
291-
{
292-
return hpx::make_ready_future();
293-
}
294293

295294
for (std::size_t i = 0; i + 1 < communicators.size(); ++i)
296295
{
@@ -343,8 +342,8 @@ namespace hpx::distributed {
343342
void detach();
344343

345344
// Get the instance of the global barrier
346-
static std::array<barrier, 2>& get_global_barrier();
347-
static std::array<barrier, 2> create_global_barrier();
345+
static barrier& get_global_barrier();
346+
static barrier create_global_barrier();
348347

349348
static void synchronize();
350349

libs/full/collectives/src/barrier.cpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
#include <hpx/modules/type_support.hpp>
2121

2222
#include <algorithm>
23-
#include <array>
2423
#include <atomic>
2524
#include <cstddef>
2625
#include <iterator>
2726
#include <string>
27+
#include <type_traits>
2828
#include <utility>
2929
#include <variant>
3030
#include <vector>
@@ -227,7 +227,7 @@ namespace hpx::distributed {
227227
{
228228
wait(hpx::launch::async).get();
229229
}
230-
catch (...)
230+
catch (...) // NOLINT(bugprone-empty-catch)
231231
{
232232
}
233233
}
@@ -242,28 +242,26 @@ namespace hpx::distributed {
242242
comm_ = std::monostate{};
243243
}
244244

245-
std::array<barrier, 2> barrier::create_global_barrier()
245+
barrier barrier::create_global_barrier()
246246
{
247247
runtime& rt = get_runtime();
248248
util::runtime_configuration const& cfg = rt.get_config();
249249
auto const num = static_cast<std::size_t>(cfg.get_num_localities());
250250
auto const rank = static_cast<std::size_t>(hpx::get_locality_id());
251-
barrier b1("/0/hpx/global_barrier0", num, rank, force_flat_tag::tag);
252-
barrier b2("/0/hpx/global_barrier1", num, rank, force_flat_tag::tag);
253-
return {{std::move(b1), std::move(b2)}};
251+
return barrier("/0/hpx/global_barrier", num, rank, force_flat_tag::tag);
254252
}
255253

256-
std::array<barrier, 2>& barrier::get_global_barrier()
254+
barrier& barrier::get_global_barrier()
257255
{
258-
static std::array<barrier, 2> bs = {};
259-
return bs;
256+
static barrier b;
257+
return b;
260258
}
261259

262260
void barrier::synchronize()
263261
{
264-
std::array<barrier, 2>& b = get_global_barrier();
265-
HPX_ASSERT(!std::holds_alternative<std::monostate>(b[0].comm_));
266-
b[0].wait();
262+
barrier& b = get_global_barrier();
263+
HPX_ASSERT(!std::holds_alternative<std::monostate>(b.comm_));
264+
b.wait();
267265
}
268266
} // namespace hpx::distributed
269267

libs/full/collectives/tests/performance/benchmark_collectives.cpp

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Copyright (c) 2025 Alexander Strack
33
// Copyright (c) 2025 Hartmut Kaiser
44
// Copyright (c) 2026 Anshuman Agrawal
5+
// Copyright (c) 2026 Abhishek Bansal
56
//
67
// SPDX-License-Identifier: BSL-1.0
78
// Distributed under the Boost Software License, Version 1.0. (See accompanying
@@ -12,6 +13,7 @@
1213
#if !defined(HPX_COMPUTE_DEVICE_CODE)
1314
#include <hpx/collectives/all_gather.hpp>
1415
#include <hpx/collectives/all_reduce.hpp>
16+
#include <hpx/collectives/barrier.hpp>
1517
#include <hpx/hpx.hpp>
1618
#include <hpx/hpx_init.hpp>
1719
#include <hpx/modules/collectives.hpp>
@@ -31,6 +33,7 @@ constexpr char const* broadcast_direct_basename = "/test/broadcast_direct/";
3133
constexpr char const* gather_direct_basename = "/test/gather_direct/";
3234
constexpr char const* all_reduce_direct_basename = "/test/all_reduce_direct/";
3335
constexpr char const* all_gather_direct_basename = "/test/all_gather_direct/";
36+
constexpr char const* barrier_bench_basename = "/test/barrier_bench/";
3437

3538
struct vector_adder
3639
{
@@ -514,6 +517,45 @@ void test_all_reduce_hierarchical(int arity, int lpn, std::size_t iterations,
514517
}
515518
}
516519

520+
void test_barrier_hierarchical(int arity, int lpn, std::size_t iterations,
521+
int test_size, std::string const& operation, int fallback_threshold)
522+
{
523+
std::size_t const num_localities =
524+
hpx::get_num_localities(hpx::launch::sync);
525+
std::size_t const this_locality = hpx::get_locality_id();
526+
HPX_TEST_LTE(static_cast<std::size_t>(2), num_localities);
527+
auto communicators =
528+
create_hierarchical_communicator(barrier_bench_basename,
529+
num_sites_arg(num_localities), this_site_arg(this_locality),
530+
arity_arg(arity), generation_arg(1), root_site_arg(0),
531+
fallback_threshold < 0 ?
532+
flat_fallback_threshold_arg() :
533+
flat_fallback_threshold_arg(
534+
static_cast<std::size_t>(fallback_threshold)));
535+
char const* const barrier_sync_name = "/test/barrier/hierarchical";
536+
hpx::distributed::barrier sync(barrier_sync_name);
537+
std::vector<double> result(iterations, 0.0);
538+
539+
for (std::size_t i = 0; i != iterations; ++i)
540+
{
541+
hpx::chrono::high_resolution_timer const timer;
542+
hpx::collectives::barrier(
543+
communicators, this_site_arg(this_locality), generation_arg(i + 1))
544+
.get();
545+
sync.wait();
546+
result[i] = timer.elapsed();
547+
}
548+
549+
if (this_locality == 0)
550+
{
551+
std::string const mod_name = fallback_threshold < 0 ?
552+
std::string("hierarchical") :
553+
"hierarchical_t" + std::to_string(fallback_threshold);
554+
write_to_file(operation, mod_name, arity, num_localities, lpn,
555+
test_size, iterations, result);
556+
}
557+
}
558+
517559
////////////////////////////////////////////////////////////////////////////////////////
518560
// One shot collectives
519561
void test_one_shot_use_scatter(int lpn, std::size_t iterations, int test_size,
@@ -1116,6 +1158,36 @@ void test_multiple_use_with_generation_all_reduce(int lpn,
11161158
}
11171159
}
11181160

1161+
void test_multiple_use_with_generation_barrier(int lpn, std::size_t iterations,
1162+
int test_size, std::string const& operation)
1163+
{
1164+
std::size_t const num_localities =
1165+
hpx::get_num_localities(hpx::launch::sync);
1166+
std::size_t const this_locality = hpx::get_locality_id();
1167+
HPX_TEST_LTE(static_cast<std::size_t>(2), num_localities);
1168+
auto const barrier_client = create_communicator(barrier_bench_basename,
1169+
num_sites_arg(num_localities), this_site_arg(this_locality));
1170+
char const* const barrier_sync_name = "/test/barrier/generation";
1171+
hpx::distributed::barrier sync(barrier_sync_name);
1172+
std::vector<double> result(iterations, 0.0);
1173+
1174+
for (std::size_t i = 0; i != iterations; ++i)
1175+
{
1176+
hpx::chrono::high_resolution_timer const timer;
1177+
hpx::collectives::barrier(
1178+
barrier_client, this_site_arg(), generation_arg(i + 1))
1179+
.get();
1180+
sync.wait();
1181+
result[i] = timer.elapsed();
1182+
}
1183+
1184+
if (this_locality == 0)
1185+
{
1186+
write_to_file(operation, "multi_use", -1, num_localities, lpn,
1187+
test_size, iterations, result);
1188+
}
1189+
}
1190+
11191191
void test_all_gather_hierarchical(int arity, int lpn, std::size_t iterations,
11201192
int test_size, std::string const& operation, int fallback_threshold)
11211193
{
@@ -1367,6 +1439,19 @@ int hpx_main(hpx::program_options::variables_map& vm)
13671439
operation, fallback_threshold);
13681440
}
13691441
}
1442+
else if (operation == "barrier")
1443+
{
1444+
if (arity == -1)
1445+
{
1446+
test_multiple_use_with_generation_barrier(
1447+
lpn, iterations, test_size, operation);
1448+
}
1449+
else
1450+
{
1451+
test_barrier_hierarchical(arity, lpn, iterations, test_size,
1452+
operation, fallback_threshold);
1453+
}
1454+
}
13701455
}
13711456

13721457
return hpx::finalize();
@@ -1388,7 +1473,7 @@ int main(int argc, char* argv[])
13881473
"Number of Iteration the collective is executed")
13891474
("operation", value<std::string>()->default_value("scatter"),
13901475
"Collective Operation (scatter, reduce, broadcast, gather, "
1391-
"all_reduce, all_gather)")
1476+
"all_reduce, all_gather, barrier)")
13921477
("fallback_threshold", value<int>()->default_value(-1),
13931478
"Flat fallback threshold for hierarchical mode. -1 uses library "
13941479
"default (16). Set to 0 to force tree construction. Only meaningful "

libs/full/collectives/tests/unit/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ if(HPX_WITH_NETWORKING)
2727
${tests}
2828
all_gather_hierarchical
2929
all_reduce_hierarchical
30+
barrier_hierarchical
3031
broadcast_direct
3132
concurrent_collectives
3233
exclusive_scan_

0 commit comments

Comments
 (0)