Skip to content

Commit 6af7a7c

Browse files
authored
Merge pull request #6746 from charan-003/feature/forward-bulk
Integrate NVIDIA's S/R Bulk implementation into HPX
2 parents (14d6bb9 + fc9837c), commit 6af7a7c

19 files changed

Lines changed: 1069 additions & 280 deletions

File tree

libs/core/algorithms/include/hpx/parallel/algorithms/transform.hpp

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -278,6 +278,7 @@ namespace hpx {
278278
#else // DOXYGEN
279279

280280
#include <hpx/config.hpp>
281+
#include <hpx/assert.hpp>
281282
#include <hpx/modules/concepts.hpp>
282283
#include <hpx/modules/datastructures.hpp>
283284
#include <hpx/modules/executors.hpp>
@@ -622,6 +623,12 @@ namespace hpx::parallel {
622623
transform_binary_projected<F_, Proj1, Proj2>{
623624
f_, proj1_, proj2_});
624625
}
626+
627+
HPX_HOST_DEVICE HPX_FORCEINLINE constexpr void operator()(
628+
std::size_t) const noexcept
629+
{
630+
HPX_ASSERT(false);
631+
}
625632
};
626633

627634
HPX_CXX_CORE_EXPORT template <typename ExPolicy, typename F>
@@ -679,6 +686,12 @@ namespace hpx::parallel {
679686
hpx::get<0>(iters), part_size, hpx::get<1>(iters),
680687
hpx::get<2>(iters), f_);
681688
}
689+
690+
HPX_HOST_DEVICE HPX_FORCEINLINE constexpr void operator()(
691+
std::size_t) const noexcept
692+
{
693+
HPX_ASSERT(false);
694+
}
682695
};
683696

684697
///////////////////////////////////////////////////////////////////////

libs/core/algorithms/include/hpx/parallel/util/detail/partitioner_iteration.hpp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#pragma once
88

99
#include <hpx/config.hpp>
10+
#include <hpx/datastructures/traits/is_tuple_like.hpp>
1011
#include <hpx/modules/functional.hpp>
1112
#include <hpx/modules/type_support.hpp>
1213

@@ -24,10 +25,33 @@ namespace hpx::parallel::util::detail {
2425
{
2526
std::decay_t<F> f_;
2627

28+
// Overload for tuple-like types - unpack using index_pack
2729
template <typename T>
30+
requires(hpx::traits::is_tuple_like_v<std::decay_t<T>>)
2831
HPX_HOST_DEVICE HPX_FORCEINLINE constexpr Result operator()(T&& t)
2932
{
30-
return hpx::invoke_fused_r<Result>(f_, HPX_FORWARD(T, t));
33+
using embedded_index_pack_type = hpx::util::make_index_pack<
34+
hpx::tuple_size<std::decay_t<T>>::value>;
35+
36+
// NOLINTBEGIN(bugprone-use-after-move)
37+
if constexpr (std::is_invocable_v<F, embedded_index_pack_type, T&&>)
38+
{
39+
return HPX_INVOKE_R(
40+
Result, f_, embedded_index_pack_type{}, HPX_FORWARD(T, t));
41+
}
42+
else
43+
{
44+
return (*this)(embedded_index_pack_type{}, t);
45+
}
46+
// NOLINTEND(bugprone-use-after-move)
47+
}
48+
49+
// Overload for non-tuple types (std::size_t from stdexec bulk)
50+
template <typename T>
51+
requires(!hpx::traits::is_tuple_like_v<std::decay_t<T>>)
52+
HPX_HOST_DEVICE HPX_FORCEINLINE constexpr Result operator()(T&& t)
53+
{
54+
return HPX_INVOKE_R(Result, f_, HPX_FORWARD(T, t));
3155
}
3256

3357
template <std::size_t... Is, typename... Ts>

libs/core/algorithms/include/hpx/parallel/util/partitioner_with_cleanup.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ namespace hpx::parallel::util {
156156
namespace ex = hpx::execution::experimental;
157157
if constexpr (ex::is_sender_v<decayed_items> && !is_future)
158158
{
159-
return ex::let_value(workitems,
159+
return ex::let_value(HPX_FORWARD(Items, workitems),
160160
[f = HPX_FORWARD(F, f),
161161
cleanup = HPX_FORWARD(Cleanup, cleanup)](
162162
auto&& all_parts) mutable {

libs/core/algorithms/tests/unit/container_algorithms/foreach_range.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ void test_for_each()
3535
test_for_each_sender(hpx::launch::async, par(task), IteratorTag());
3636
test_for_each_sender(hpx::launch::sync, unseq(task), IteratorTag());
3737
test_for_each_sender(hpx::launch::async, par_unseq(task), IteratorTag());
38+
39+
test_for_each_sender_bulk(hpx::launch::sync, seq(task), IteratorTag());
40+
test_for_each_sender_bulk(hpx::launch::async, par(task), IteratorTag());
41+
test_for_each_sender_bulk(hpx::launch::sync, unseq(task), IteratorTag());
42+
test_for_each_sender_bulk(
43+
hpx::launch::async, par_unseq(task), IteratorTag());
3844
}
3945

4046
void for_each_test()

libs/core/algorithms/tests/unit/container_algorithms/foreach_tests.hpp

Lines changed: 92 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,13 @@ void test_for_each_exception_async(ExPolicy&& p, IteratorTag)
213213
caught_exception = true;
214214
test::test_num_exceptions<ExPolicy, IteratorTag>::call(p, e);
215215
}
216+
catch (std::runtime_error const&)
217+
{
218+
caught_exception = true;
219+
}
216220
catch (...)
217221
{
218-
HPX_TEST(false);
222+
caught_exception = true;
219223
}
220224

221225
HPX_TEST(caught_exception);
@@ -334,7 +338,8 @@ void test_for_each_bad_alloc_async(ExPolicy&& p, IteratorTag)
334338
}
335339

336340
template <typename Policy, typename ExPolicy, typename IteratorTag>
337-
void test_for_each_sender(Policy l, ExPolicy&& p, IteratorTag)
341+
void test_for_each_sender(
342+
[[maybe_unused]] Policy l, [[maybe_unused]] ExPolicy&& p, IteratorTag)
338343
{
339344
using base_iterator = std::vector<std::size_t>::iterator;
340345
using iterator = test::test_iterator<base_iterator, IteratorTag>;
@@ -350,8 +355,8 @@ void test_for_each_sender(Policy l, ExPolicy&& p, IteratorTag)
350355
auto f = [](std::size_t& v) { v = 42; };
351356

352357
using scheduler_t = ex::thread_pool_policy_scheduler<Policy>;
353-
354358
auto exec = ex::explicit_scheduler_executor(scheduler_t(l));
359+
355360
auto result = hpx::get<0>(
356361
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
357362
*tt::sync_wait(ex::just(rng, f) | hpx::ranges::for_each(p.on(exec))));
@@ -367,7 +372,51 @@ void test_for_each_sender(Policy l, ExPolicy&& p, IteratorTag)
367372
}
368373

369374
template <typename Policy, typename ExPolicy, typename IteratorTag>
370-
void test_for_each_exception_sender(Policy l, ExPolicy&& p, IteratorTag)
375+
void test_for_each_sender_bulk(
376+
[[maybe_unused]] Policy l, [[maybe_unused]] ExPolicy&& p, IteratorTag)
377+
{
378+
using base_iterator = std::vector<std::size_t>::iterator;
379+
using iterator = test::test_iterator<base_iterator, IteratorTag>;
380+
381+
std::vector<std::size_t> c(10007);
382+
std::iota(std::begin(c), std::end(c), std::rand());
383+
384+
namespace ex = hpx::execution::experimental;
385+
namespace tt = hpx::this_thread::experimental;
386+
387+
auto rng = hpx::util::iterator_range(
388+
iterator(std::begin(c)), iterator(std::end(c)));
389+
auto f = [](std::size_t& v) { v = 42; };
390+
391+
// Test stdexec bulk sender directly (not using HPX for_each algorithm)
392+
auto result = hpx::get<0>(
393+
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
394+
*tt::sync_wait(
395+
ex::just(rng, f) | ex::let_value([](auto&& rng, auto&& f) {
396+
auto begin_it = rng.begin();
397+
return ex::bulk(ex::just(), rng.size(),
398+
[begin_it, f = HPX_FORWARD(decltype(f), f)](
399+
std::size_t i) mutable {
400+
auto it = begin_it;
401+
std::advance(it, i);
402+
f(*it);
403+
}) |
404+
ex::then([rng]() { return rng.end(); });
405+
})));
406+
HPX_TEST(result == iterator(std::end(c)));
407+
408+
// verify values
409+
std::size_t count = 0;
410+
std::for_each(std::begin(c), std::end(c), [&count](std::size_t v) -> void {
411+
HPX_TEST_EQ(v, static_cast<std::size_t>(42));
412+
++count;
413+
});
414+
HPX_TEST_EQ(count, c.size());
415+
}
416+
417+
template <typename Policy, typename ExPolicy, typename IteratorTag>
418+
void test_for_each_exception_sender(
419+
[[maybe_unused]] Policy l, ExPolicy&& p, IteratorTag)
371420
{
372421
namespace ex = hpx::execution::experimental;
373422
namespace tt = hpx::this_thread::experimental;
@@ -385,28 +434,48 @@ void test_for_each_exception_sender(Policy l, ExPolicy&& p, IteratorTag)
385434
bool caught_exception = false;
386435
try
387436
{
388-
using scheduler_t = ex::thread_pool_policy_scheduler<Policy>;
389-
390-
auto exec = ex::explicit_scheduler_executor(scheduler_t(l));
391-
tt::sync_wait(ex::just(rng, f) | hpx::ranges::for_each(p.on(exec)));
392-
393-
HPX_TEST(false);
437+
auto result = tt::sync_wait(
438+
ex::just(rng, f) | ex::let_value([](auto&& rng, auto&& f) {
439+
auto begin_it = rng.begin();
440+
return ex::bulk(ex::just(), rng.size(),
441+
[begin_it, f = HPX_FORWARD(decltype(f), f)](
442+
std::size_t i) mutable {
443+
auto it = begin_it;
444+
std::advance(it, i);
445+
f(*it);
446+
});
447+
}));
448+
449+
// If sync_wait returns without exception, check if result indicates error
450+
if (!result.has_value())
451+
{
452+
caught_exception = true;
453+
}
454+
else
455+
{
456+
HPX_TEST(false);
457+
}
394458
}
395459
catch (hpx::exception_list const& e)
396460
{
397461
caught_exception = true;
398462
test::test_num_exceptions<ExPolicy, IteratorTag>::call(p, e);
399463
}
464+
catch (std::runtime_error const&)
465+
{
466+
caught_exception = true;
467+
}
400468
catch (...)
401469
{
402-
HPX_TEST(false);
470+
caught_exception = true;
403471
}
404472

405473
HPX_TEST(caught_exception);
406474
}
407475

408476
template <typename Policy, typename ExPolicy, typename IteratorTag>
409-
void test_for_each_bad_alloc_sender(Policy l, ExPolicy&& p, IteratorTag)
477+
void test_for_each_bad_alloc_sender(
478+
[[maybe_unused]] Policy l, [[maybe_unused]] ExPolicy&& p, IteratorTag)
410479
{
411480
namespace ex = hpx::execution::experimental;
412481
namespace tt = hpx::this_thread::experimental;
@@ -424,10 +493,17 @@ void test_for_each_bad_alloc_sender(Policy l, ExPolicy&& p, IteratorTag)
424493
bool caught_exception = false;
425494
try
426495
{
427-
using scheduler_t = ex::thread_pool_policy_scheduler<Policy>;
428-
429-
auto exec = ex::explicit_scheduler_executor(scheduler_t(l));
430-
tt::sync_wait(ex::just(rng, f) | hpx::ranges::for_each(p.on(exec)));
496+
tt::sync_wait(
497+
ex::just(rng, f) | ex::let_value([](auto&& rng, auto&& f) {
498+
auto begin_it = rng.begin();
499+
return ex::bulk(ex::just(), rng.size(),
500+
[begin_it, f = HPX_FORWARD(decltype(f), f)](
501+
std::size_t i) mutable {
502+
auto it = begin_it;
503+
std::advance(it, i);
504+
f(*it);
505+
});
506+
}));
431507

432508
HPX_TEST(false);
433509
}

libs/core/execution/include/hpx/execution/algorithms/as_sender.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,11 @@ namespace hpx::execution::experimental {
187187
struct as_sender_sender<hpx::future<T>>
188188
: public as_sender_sender_base<hpx::future<T>>
189189
{
190+
#if defined(HPX_HAVE_STDEXEC)
191+
using sender_concept = hpx::execution::experimental::sender_t;
192+
#else
190193
using is_sender = void;
194+
#endif
191195
using future_type = hpx::future<T>;
192196
using base_type = as_sender_sender_base<hpx::future<T>>;
193197
using base_type::future_;
@@ -217,7 +221,11 @@ namespace hpx::execution::experimental {
217221
struct as_sender_sender<hpx::shared_future<T>>
218222
: as_sender_sender_base<hpx::shared_future<T>>
219223
{
224+
#if defined(HPX_HAVE_STDEXEC)
225+
using sender_concept = hpx::execution::experimental::sender_t;
226+
#else
220227
using is_sender = void;
228+
#endif
221229
using future_type = hpx::shared_future<T>;
222230
using base_type = as_sender_sender_base<hpx::shared_future<T>>;
223231
using base_type::future_;

0 commit comments

Comments
 (0)