22// Copyright (c) 2025 Alexander Strack
33// Copyright (c) 2025 Hartmut Kaiser
44// Copyright (c) 2026 Anshuman Agrawal
5+ // Copyright (c) 2026 Abhishek Bansal
56//
67// SPDX-License-Identifier: BSL-1.0
78// Distributed under the Boost Software License, Version 1.0. (See accompanying
1213#if !defined(HPX_COMPUTE_DEVICE_CODE)
1314#include < hpx/collectives/all_gather.hpp>
1415#include < hpx/collectives/all_reduce.hpp>
16+ #include < hpx/collectives/barrier.hpp>
1517#include < hpx/hpx.hpp>
1618#include < hpx/hpx_init.hpp>
1719#include < hpx/modules/collectives.hpp>
@@ -31,6 +33,7 @@ constexpr char const* broadcast_direct_basename = "/test/broadcast_direct/";
3133constexpr char const * gather_direct_basename = " /test/gather_direct/" ;
3234constexpr char const * all_reduce_direct_basename = " /test/all_reduce_direct/" ;
3335constexpr char const * all_gather_direct_basename = " /test/all_gather_direct/" ;
36+ constexpr char const * barrier_bench_basename = " /test/barrier_bench/" ;
3437
3538struct vector_adder
3639{
@@ -514,6 +517,45 @@ void test_all_reduce_hierarchical(int arity, int lpn, std::size_t iterations,
514517 }
515518}
516519
520+ void test_barrier_hierarchical (int arity, int lpn, std::size_t iterations,
521+ int test_size, std::string const & operation, int fallback_threshold)
522+ {
523+ std::size_t const num_localities =
524+ hpx::get_num_localities (hpx::launch::sync);
525+ std::size_t const this_locality = hpx::get_locality_id ();
526+ HPX_TEST_LTE (static_cast <std::size_t >(2 ), num_localities);
527+ auto communicators =
528+ create_hierarchical_communicator (barrier_bench_basename,
529+ num_sites_arg (num_localities), this_site_arg (this_locality),
530+ arity_arg (arity), generation_arg (1 ), root_site_arg (0 ),
531+ fallback_threshold < 0 ?
532+ flat_fallback_threshold_arg () :
533+ flat_fallback_threshold_arg (
534+ static_cast <std::size_t >(fallback_threshold)));
535+ char const * const barrier_sync_name = " /test/barrier/hierarchical" ;
536+ hpx::distributed::barrier sync (barrier_sync_name);
537+ std::vector<double > result (iterations, 0.0 );
538+
539+ for (std::size_t i = 0 ; i != iterations; ++i)
540+ {
541+ hpx::chrono::high_resolution_timer const timer;
542+ hpx::collectives::barrier (
543+ communicators, this_site_arg (this_locality), generation_arg (i + 1 ))
544+ .get ();
545+ sync.wait ();
546+ result[i] = timer.elapsed ();
547+ }
548+
549+ if (this_locality == 0 )
550+ {
551+ std::string const mod_name = fallback_threshold < 0 ?
552+ std::string (" hierarchical" ) :
553+ " hierarchical_t" + std::to_string (fallback_threshold);
554+ write_to_file (operation, mod_name, arity, num_localities, lpn,
555+ test_size, iterations, result);
556+ }
557+ }
558+
517559// //////////////////////////////////////////////////////////////////////////////////////
518560// One shot collectives
519561void test_one_shot_use_scatter (int lpn, std::size_t iterations, int test_size,
@@ -1116,6 +1158,36 @@ void test_multiple_use_with_generation_all_reduce(int lpn,
11161158 }
11171159}
11181160
1161+ void test_multiple_use_with_generation_barrier (int lpn, std::size_t iterations,
1162+ int test_size, std::string const & operation)
1163+ {
1164+ std::size_t const num_localities =
1165+ hpx::get_num_localities (hpx::launch::sync);
1166+ std::size_t const this_locality = hpx::get_locality_id ();
1167+ HPX_TEST_LTE (static_cast <std::size_t >(2 ), num_localities);
1168+ auto const barrier_client = create_communicator (barrier_bench_basename,
1169+ num_sites_arg (num_localities), this_site_arg (this_locality));
1170+ char const * const barrier_sync_name = " /test/barrier/generation" ;
1171+ hpx::distributed::barrier sync (barrier_sync_name);
1172+ std::vector<double > result (iterations, 0.0 );
1173+
1174+ for (std::size_t i = 0 ; i != iterations; ++i)
1175+ {
1176+ hpx::chrono::high_resolution_timer const timer;
1177+ hpx::collectives::barrier (
1178+ barrier_client, this_site_arg (), generation_arg (i + 1 ))
1179+ .get ();
1180+ sync.wait ();
1181+ result[i] = timer.elapsed ();
1182+ }
1183+
1184+ if (this_locality == 0 )
1185+ {
1186+ write_to_file (operation, " multi_use" , -1 , num_localities, lpn,
1187+ test_size, iterations, result);
1188+ }
1189+ }
1190+
11191191void test_all_gather_hierarchical (int arity, int lpn, std::size_t iterations,
11201192 int test_size, std::string const & operation, int fallback_threshold)
11211193{
@@ -1367,6 +1439,19 @@ int hpx_main(hpx::program_options::variables_map& vm)
13671439 operation, fallback_threshold);
13681440 }
13691441 }
1442+ else if (operation == " barrier" )
1443+ {
1444+ if (arity == -1 )
1445+ {
1446+ test_multiple_use_with_generation_barrier (
1447+ lpn, iterations, test_size, operation);
1448+ }
1449+ else
1450+ {
1451+ test_barrier_hierarchical (arity, lpn, iterations, test_size,
1452+ operation, fallback_threshold);
1453+ }
1454+ }
13701455 }
13711456
13721457 return hpx::finalize ();
@@ -1388,7 +1473,7 @@ int main(int argc, char* argv[])
13881473 " Number of Iteration the collective is executed" )
13891474 (" operation" , value<std::string>()->default_value (" scatter" ),
13901475 " Collective Operation (scatter, reduce, broadcast, gather, "
1391- " all_reduce, all_gather)" )
1476+ " all_reduce, all_gather, barrier )" )
13921477 (" fallback_threshold" , value<int >()->default_value (-1 ),
13931478 " Flat fallback threshold for hierarchical mode. -1 uses library "
13941479 " default (16). Set to 0 to force tree construction. Only meaningful "
0 commit comments