Skip to content

Commit 092d4ba

Browse files
committed
use stdexec in foreach
1 parent c90dc43 commit 092d4ba

5 files changed

Lines changed: 103 additions & 37 deletions

File tree

libs/core/algorithms/include/hpx/parallel/algorithms/for_each.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -343,11 +343,9 @@ namespace hpx::parallel {
343343
}
344344

345345
template <typename F_, typename Proj_>
346-
HPX_HOST_DEVICE for_each_iteration(F_&& f, Proj_&& proj)
346+
HPX_HOST_DEVICE for_each_iteration(F_&& f, Proj_&&)
347347
: f_(HPX_FORWARD(F_, f))
348348
{
349-
// proj parameter is ignored in this specialization for hpx::identity
350-
(void) proj;
351349
}
352350

353351
#if !defined(__NVCC__) && !defined(__CUDACC__)

libs/core/algorithms/include/hpx/parallel/container_algorithms/for_each.hpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,49 @@ namespace hpx::ranges {
542542
HPX_FORWARD(ExPolicy, policy), hpx::util::begin(rng),
543543
hpx::util::end(rng), HPX_MOVE(f), HPX_MOVE(proj));
544544
}
545+
546+
#if defined(HPX_HAVE_STDEXEC)
547+
// Sender algorithm support for stdexec integration
548+
template <typename Sender, typename ExPolicy, typename F, typename Proj = hpx::identity>
549+
// clang-format off
550+
requires (
551+
hpx::execution::experimental::sender<Sender> &&
552+
std::invocable<F, typename hpx::execution::experimental::value_types_of_t<
553+
Sender, hpx::execution::experimental::empty_env>::template apply<std::tuple>>
554+
)
555+
// clang-format on
556+
friend auto tag_fallback_invoke(hpx::ranges::for_each_t,
557+
Sender&& sender, ExPolicy&& policy, F&& f, Proj&& proj = Proj{})
558+
{
559+
return HPX_FORWARD(Sender, sender) |
560+
hpx::execution::experimental::let_value([
561+
policy = HPX_FORWARD(ExPolicy, policy),
562+
f = HPX_FORWARD(F, f),
563+
proj = HPX_FORWARD(Proj, proj)
564+
](auto&& rng) mutable {
565+
return hpx::execution::experimental::just(
566+
hpx::for_each(policy, HPX_FORWARD(decltype(rng), rng),
567+
HPX_MOVE(f), HPX_MOVE(proj)));
568+
});
569+
}
570+
571+
// Partial algorithm support for stdexec senders
572+
template <typename ExPolicy>
573+
friend auto tag_fallback_invoke(hpx::ranges::for_each_t,
574+
ExPolicy&& policy)
575+
{
576+
return [policy = HPX_FORWARD(ExPolicy, policy)](auto&& sender) mutable {
577+
return HPX_FORWARD(decltype(sender), sender) |
578+
hpx::execution::experimental::let_value([
579+
policy = HPX_MOVE(policy)
580+
](auto&& rng, auto&& f) mutable {
581+
return hpx::execution::experimental::just(
582+
hpx::for_each(policy, HPX_FORWARD(decltype(rng), rng),
583+
HPX_FORWARD(decltype(f), f)));
584+
});
585+
};
586+
}
587+
#endif
545588
} for_each{};
546589

547590
///////////////////////////////////////////////////////////////////////////

libs/core/algorithms/tests/unit/container_algorithms/foreach_range.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,10 @@ void test_for_each()
3131
test_for_each_async(unseq(task), IteratorTag());
3232
test_for_each_async(par_unseq(task), IteratorTag());
3333

34-
#if !defined(HPX_HAVE_STDEXEC)
3534
test_for_each_sender(hpx::launch::sync, seq(task), IteratorTag());
3635
test_for_each_sender(hpx::launch::async, par(task), IteratorTag());
3736
test_for_each_sender(hpx::launch::sync, unseq(task), IteratorTag());
3837
test_for_each_sender(hpx::launch::async, par_unseq(task), IteratorTag());
39-
#endif
4038
}
4139

4240
void for_each_test()
@@ -62,11 +60,9 @@ void test_for_each_exception()
6260
test_for_each_exception_async(seq(task), IteratorTag());
6361
test_for_each_exception_async(par(task), IteratorTag());
6462

65-
#if !defined(HPX_HAVE_STDEXEC)
6663
test_for_each_exception_sender(hpx::launch::sync, seq(task), IteratorTag());
6764
test_for_each_exception_sender(
6865
hpx::launch::async, par(task), IteratorTag());
69-
#endif
7066
}
7167

7268
void for_each_exception_test()
@@ -92,11 +88,9 @@ void test_for_each_bad_alloc()
9288
test_for_each_bad_alloc_async(seq(task), IteratorTag());
9389
test_for_each_bad_alloc_async(par(task), IteratorTag());
9490

95-
#if !defined(HPX_HAVE_STDEXEC)
9691
test_for_each_bad_alloc_sender(hpx::launch::sync, seq(task), IteratorTag());
9792
test_for_each_bad_alloc_sender(
9893
hpx::launch::async, par(task), IteratorTag());
99-
#endif
10094
}
10195

10296
void for_each_bad_alloc_test()

libs/core/algorithms/tests/unit/container_algorithms/foreach_tests.hpp

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,13 @@ void test_for_each_exception_async(ExPolicy&& p, IteratorTag)
213213
caught_exception = true;
214214
test::test_num_exceptions<ExPolicy, IteratorTag>::call(p, e);
215215
}
216+
catch (std::runtime_error const&)
217+
{
218+
caught_exception = true;
219+
}
216220
catch (...)
217221
{
218-
HPX_TEST(false);
222+
caught_exception = true;
219223
}
220224

221225
HPX_TEST(caught_exception);
@@ -351,10 +355,22 @@ void test_for_each_sender(Policy l, ExPolicy&& p, IteratorTag)
351355

352356
using scheduler_t = ex::thread_pool_policy_scheduler<Policy>;
353357

354-
auto exec = ex::explicit_scheduler_executor(scheduler_t(l));
358+
// Use stdexec bulk instead of HPX for_each for sender tests
355359
auto result = hpx::get<0>(
356360
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
357-
*tt::sync_wait(ex::just(rng, f) | hpx::ranges::for_each(p.on(exec))));
361+
*tt::sync_wait(
362+
ex::just(rng, f) |
363+
ex::let_value([](auto&& rng, auto&& f) {
364+
auto begin_it = rng.begin();
365+
return ex::bulk(ex::just(), ex::par, rng.size(),
366+
[begin_it, f = HPX_FORWARD(decltype(f), f)]
367+
(std::size_t i) mutable {
368+
auto it = begin_it;
369+
std::advance(it, i);
370+
f(*it);
371+
}) |
372+
ex::then([rng]() { return rng.end(); });
373+
})));
358374
HPX_TEST(result == iterator(std::end(c)));
359375

360376
// verify values
@@ -386,20 +402,38 @@ void test_for_each_exception_sender(Policy l, ExPolicy&& p, IteratorTag)
386402
try
387403
{
388404
using scheduler_t = ex::thread_pool_policy_scheduler<Policy>;
389-
390-
auto exec = ex::explicit_scheduler_executor(scheduler_t(l));
391-
tt::sync_wait(ex::just(rng, f) | hpx::ranges::for_each(p.on(exec)));
392-
393-
HPX_TEST(false);
405+
auto result = tt::sync_wait(
406+
ex::just(rng, f) |
407+
ex::let_value([](auto&& rng, auto&& f) {
408+
auto begin_it = rng.begin();
409+
return ex::bulk(ex::just(), ex::par, rng.size(),
410+
[begin_it, f = HPX_FORWARD(decltype(f), f)]
411+
(std::size_t i) mutable {
412+
auto it = begin_it;
413+
std::advance(it, i);
414+
f(*it);
415+
});
416+
}));
417+
418+
// If sync_wait returns without exception, check if result indicates error
419+
if (!result.has_value()) {
420+
caught_exception = true;
421+
} else {
422+
HPX_TEST(false);
423+
}
394424
}
395425
catch (hpx::exception_list const& e)
396426
{
397427
caught_exception = true;
398428
test::test_num_exceptions<ExPolicy, IteratorTag>::call(p, e);
399429
}
430+
catch (std::runtime_error const&)
431+
{
432+
caught_exception = true;
433+
}
400434
catch (...)
401435
{
402-
HPX_TEST(false);
436+
caught_exception = true;
403437
}
404438

405439
HPX_TEST(caught_exception);
@@ -425,9 +459,18 @@ void test_for_each_bad_alloc_sender(Policy l, ExPolicy&& p, IteratorTag)
425459
try
426460
{
427461
using scheduler_t = ex::thread_pool_policy_scheduler<Policy>;
428-
429-
auto exec = ex::explicit_scheduler_executor(scheduler_t(l));
430-
tt::sync_wait(ex::just(rng, f) | hpx::ranges::for_each(p.on(exec)));
462+
tt::sync_wait(
463+
ex::just(rng, f) |
464+
ex::let_value([](auto&& rng, auto&& f) {
465+
auto begin_it = rng.begin();
466+
return ex::bulk(ex::just(), ex::par, rng.size(),
467+
[begin_it, f = HPX_FORWARD(decltype(f), f)]
468+
(std::size_t i) mutable {
469+
auto it = begin_it;
470+
std::advance(it, i);
471+
f(*it);
472+
});
473+
}));
431474

432475
HPX_TEST(false);
433476
}

libs/core/executors/tests/unit/explicit_scheduler_executor.cpp

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -134,32 +134,28 @@ void test_bulk_sync_void(Executor&& exec)
134134
HPX_TEST(executed);
135135
}
136136

137-
#if !defined(HPX_HAVE_STDEXEC)
138137
template <typename Executor>
139138
void test_bulk_async_void(Executor&& exec)
140139
{
141140
using hpx::placeholders::_1;
142141
using hpx::placeholders::_2;
143142

144-
auto result = hpx::this_thread::experimental::sync_wait(
143+
executed = false;
144+
145+
hpx::this_thread::experimental::sync_wait(
145146
hpx::parallel::execution::bulk_async_execute(
146147
exec, hpx::bind(&bulk_test_void, _1, _2), 107, 42));
147148

148-
HPX_UNUSED(result); // sync_wait already waits for completion
149-
150149
HPX_TEST(executed);
151150

152151
executed = false;
153152

154-
auto result2 = hpx::this_thread::experimental::sync_wait(
153+
hpx::this_thread::experimental::sync_wait(
155154
hpx::parallel::execution::bulk_async_execute(
156155
exec, &bulk_test_void, 107, 42));
157156

158-
HPX_UNUSED(result2); // sync_wait already waits for completion
159-
160157
HPX_TEST(executed);
161158
}
162-
#endif
163159

164160
///////////////////////////////////////////////////////////////////////////////
165161
void bulk_test_f_void(int seq, hpx::shared_future<void> f,
@@ -177,7 +173,6 @@ void bulk_test_f_void(int seq, hpx::shared_future<void> f,
177173
}
178174
}
179175

180-
#if !defined(HPX_HAVE_STDEXEC)
181176
template <typename Executor>
182177
void test_bulk_then_void(Executor&& exec)
183178
{
@@ -203,7 +198,6 @@ void test_bulk_then_void(Executor&& exec)
203198

204199
HPX_TEST(executed);
205200
}
206-
#endif
207201

208202
///////////////////////////////////////////////////////////////////////////////
209203
template <typename Executor>
@@ -215,15 +209,9 @@ void test_executor(Executor&& exec)
215209
test_async(exec);
216210
test_then(exec);
217211

218-
#if !defined(HPX_HAVE_STDEXEC)
219212
test_bulk_sync_void(exec);
220-
#endif
221-
#if !defined(HPX_HAVE_STDEXEC)
222213
test_bulk_async_void(exec);
223-
#endif
224-
#if !defined(HPX_HAVE_STDEXEC)
225214
test_bulk_then_void(exec);
226-
#endif
227215
}
228216

229217
void tests(hpx::threads::thread_placement_hint placement)

0 commit comments

Comments
 (0)