Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,7 @@ namespace hpx::execution::experimental {

// Additional stdexec concepts and utilities needed for domain customization
HPX_CXX_CORE_EXPORT using stdexec::__completes_on;
HPX_CXX_CORE_EXPORT using stdexec::__starts_on;
HPX_CXX_CORE_EXPORT using stdexec::sender_expr_for;
HPX_CXX_CORE_EXPORT using stdexec::__sender_for;
} // namespace stdexec_internal
} // namespace hpx::execution::experimental

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,13 @@ namespace hpx::execution::experimental {
explicit_scheduler_executor const& exec, F&& f, S const& shape,
Ts&&... ts)
{
#if defined(HPX_HAVE_STDEXEC)
return bulk(schedule(exec.sched_), par, shape,
hpx::bind_back(HPX_FORWARD(F, f), HPX_FORWARD(Ts, ts)...));
#else
return bulk(schedule(exec.sched_), shape,
hpx::bind_back(HPX_FORWARD(F, f), HPX_FORWARD(Ts, ts)...));
#endif
}

// Range shape overload
Expand Down Expand Up @@ -206,8 +211,21 @@ namespace hpx::execution::experimental {

if constexpr (std::is_void_v<result_type>)
{
#if defined(HPX_HAVE_STDEXEC)
// stdexec::bulk requires integral shape and execution policy
using size_type = decltype(util::size(shape));
size_type const n = util::size(shape);
return bulk(schedule(exec.sched_), par, n,
Comment thread
hkaiser marked this conversation as resolved.
[shape, f = HPX_FORWARD(F, f),
... args = HPX_FORWARD(Ts, ts)](size_type i) mutable {
auto it = util::begin(shape);
std::advance(it, i);
HPX_INVOKE(f, *it, args...);
});
#else
return bulk(schedule(exec.sched_), shape,
hpx::bind_back(HPX_FORWARD(F, f), HPX_FORWARD(Ts, ts)...));
#endif
}
else
{
Expand All @@ -220,21 +238,19 @@ namespace hpx::execution::experimental {
auto f_wrapper = [](size_type const i,
result_vector_type& result_vector,
S const& shape, F& f, Ts&... ts) {
auto it = util::begin(shape);
std::advance(it, i);
result_vector[i] = HPX_INVOKE(f, *it, ts...);
auto it = std::begin(shape);
result_vector[i] = HPX_INVOKE(f, *std::next(it, i), ts...);
};

auto get_result = [](result_vector_type&& result_vector,
S const&, F&&, Ts&&...) {
return HPX_MOVE(result_vector);
};
auto get_result =
[](result_vector_type&& result_vector, S const&, F&&,
Ts&&...) { return HPX_MOVE(result_vector); };

#if defined(HPX_HAVE_STDEXEC)
return just(HPX_MOVE(result_vector), shape, HPX_FORWARD(F, f),
HPX_FORWARD(Ts, ts)...) |
continues_on(exec.sched_) |
bulk(shape_size, HPX_MOVE(f_wrapper)) |
bulk(par, shape_size, HPX_MOVE(f_wrapper)) |
then(HPX_MOVE(get_result));
#else
return then(
Expand Down Expand Up @@ -294,9 +310,15 @@ namespace hpx::execution::experimental {
auto pre_req =
when_all(keep_future(HPX_FORWARD(Future, predecessor)));

#if defined(HPX_HAVE_STDEXEC)
return transfer(HPX_MOVE(pre_req), exec.sched_) |
bulk(par, shape,
hpx::bind_back(HPX_FORWARD(F, f), HPX_FORWARD(Ts, ts)...));
#else
return transfer(HPX_MOVE(pre_req), exec.sched_) |
bulk(shape,
hpx::bind_back(HPX_FORWARD(F, f), HPX_FORWARD(Ts, ts)...));
#endif
}

// Range shape overload
Expand All @@ -320,9 +342,23 @@ namespace hpx::execution::experimental {
auto pre_req =
when_all(keep_future(HPX_FORWARD(Future, predecessor)));

#if defined(HPX_HAVE_STDEXEC)
using size_type = decltype(util::size(shape));
size_type const n = util::size(shape);
return transfer(HPX_MOVE(pre_req), exec.sched_) |
bulk(par, n,
[shape, f = HPX_FORWARD(F, f),
... args = HPX_FORWARD(Ts, ts)](
size_type i, auto&... receiver_args) mutable {
auto it = util::begin(shape);
std::advance(it, i);
HPX_INVOKE(f, *it, args..., receiver_args...);
});
#else
return transfer(HPX_MOVE(pre_req), exec.sched_) |
bulk(shape,
hpx::bind_back(HPX_FORWARD(F, f), HPX_FORWARD(Ts, ts)...));
#endif
}

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,21 @@ namespace hpx::execution::experimental {

if constexpr (std::is_void_v<result_type>)
{
#if defined(HPX_HAVE_STDEXEC)
// stdexec::bulk requires integral shape and execution policy
using size_type = decltype(hpx::util::size(shape));
size_type const n = hpx::util::size(shape);
return make_future(bulk(schedule(exec.sched_), par, n,
[shape, f = HPX_FORWARD(F, f),
... args = HPX_FORWARD(Ts, ts)](size_type i) mutable {
auto it = hpx::util::begin(shape);
std::advance(it, i);
HPX_INVOKE(f, *it, args...);
}));
#else
return make_future(bulk(schedule(exec.sched_), shape,
hpx::bind_back(HPX_FORWARD(F, f), HPX_FORWARD(Ts, ts)...)));
#endif
}
else
{
Expand Down Expand Up @@ -215,10 +228,17 @@ namespace hpx::execution::experimental {
});
};

#if defined(HPX_HAVE_STDEXEC)
start_detached(
bulk(transfer_just(exec.sched_, HPX_MOVE(promises),
HPX_FORWARD(F, f), shape, HPX_FORWARD(Ts, ts)...),
par, n, HPX_MOVE(f_helper)));
#else
start_detached(
bulk(transfer_just(exec.sched_, HPX_MOVE(promises),
HPX_FORWARD(F, f), shape, HPX_FORWARD(Ts, ts)...),
n, HPX_MOVE(f_helper)));
#endif

return results;
}
Expand All @@ -234,12 +254,29 @@ namespace hpx::execution::experimental {
using result_type = hpx::util::detail::invoke_deferred_result_t<F,
shape_element, Ts...>;

#if defined(HPX_HAVE_STDEXEC)
// stdexec::bulk requires integral shape and execution policy
using size_type = decltype(hpx::util::size(shape));
size_type const n = hpx::util::size(shape);
return hpx::util::void_guard<result_type>(),
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
*hpx::this_thread::experimental::sync_wait(
bulk(schedule(exec.sched_), par, n,
[shape, f = HPX_FORWARD(F, f),
... args = HPX_FORWARD(Ts, ts)](
size_type i) mutable {
auto it = hpx::util::begin(shape);
std::advance(it, i);
HPX_INVOKE(f, *it, args...);
}));
#else
return hpx::util::void_guard<result_type>(),
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
*hpx::this_thread::experimental::sync_wait(
bulk(schedule(exec.sched_), shape,
hpx::bind_back(
HPX_FORWARD(F, f), HPX_FORWARD(Ts, ts)...)));
#endif
}

template <typename F, typename S, typename Future, typename... Ts>
Expand All @@ -259,9 +296,23 @@ namespace hpx::execution::experimental {
auto pre_req =
when_all(keep_future(HPX_FORWARD(Future, predecessor)));

#if defined(HPX_HAVE_STDEXEC)
using size_type = decltype(hpx::util::size(shape));
size_type const n = hpx::util::size(shape);
auto loop =
bulk(transfer(HPX_MOVE(pre_req), exec.sched_), par, n,
[shape, f = HPX_FORWARD(F, f),
... args = HPX_FORWARD(Ts, ts)](
size_type i, auto&... receiver_args) mutable {
auto it = hpx::util::begin(shape);
std::advance(it, i);
HPX_INVOKE(f, *it, args..., receiver_args...);
});
#else
auto loop = bulk(transfer(HPX_MOVE(pre_req), exec.sched_),
shape,
hpx::bind_back(HPX_FORWARD(F, f), HPX_FORWARD(Ts, ts)...));
#endif

return make_future(HPX_MOVE(loop));
}
Expand All @@ -272,10 +323,19 @@ namespace hpx::execution::experimental {
when_all(keep_future(HPX_FORWARD(Future, predecessor)),
just(std::vector<result_type>(hpx::util::size(shape))));

#if defined(HPX_HAVE_STDEXEC)
using size_type = decltype(hpx::util::size(shape));
size_type const n = hpx::util::size(shape);
auto loop =
bulk(transfer(HPX_MOVE(pre_req), exec.sched_), par, n,
detail::captured_args_then(
HPX_FORWARD(F, f), HPX_FORWARD(Ts, ts)...));
#else
auto loop =
bulk(transfer(HPX_MOVE(pre_req), exec.sched_), shape,
detail::captured_args_then(
HPX_FORWARD(F, f), HPX_FORWARD(Ts, ts)...));
#endif

return make_future(then(
HPX_MOVE(loop), [](auto&&, std::vector<result_type>&& v) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,32 +67,23 @@ namespace hpx::execution::experimental {
// Concept to match bulk sender types
template <typename Sender>
concept bulk_chunked_or_unchunked_sender =
hpx::execution::experimental::stdexec_internal::sender_expr_for<Sender,
hpx::execution::experimental::stdexec_internal::__sender_for<Sender,
hpx::execution::experimental::bulk_chunked_t> ||
hpx::execution::experimental::stdexec_internal::sender_expr_for<Sender,
hpx::execution::experimental::stdexec_internal::__sender_for<Sender,
hpx::execution::experimental::bulk_unchunked_t>;

// Domain customization for stdexec bulk operations
// Only the env-based transform_sender is provided. The early (no-env)
// transform falls through to default_domain, and the late transform
// handles both completes_on and starts_on patterns at connection time.
// Following the stdexec parallel_scheduler pattern (set_value_t tag-based).
template <typename Policy>
struct thread_pool_domain : stdexec::default_domain
{
// transform_sender for bulk operations
// (following stdexec system_context.hpp pattern env-based only)
// (following stdexec parallel_scheduler pattern)
template <bulk_chunked_or_unchunked_sender Sender, typename Env>
auto transform_sender(Sender&& sndr, Env const& env) const noexcept
constexpr auto transform_sender(
hpx::execution::experimental::set_value_t, Sender&& sndr,
Env const& env) const noexcept
{
static_assert(
hpx::execution::experimental::stdexec_internal::__completes_on<
Sender, thread_pool_policy_scheduler<Policy>, Env> ||
hpx::execution::experimental::stdexec_internal::__starts_on<
Sender, thread_pool_policy_scheduler<Policy>, Env>,
"No thread_pool_policy_scheduler instance can be found in the "
"sender's attributes or receiver's environment "
"on which to schedule bulk work.");

auto sched = hpx::execution::experimental::get_scheduler(env);

// Extract bulk parameters using structured binding
Expand All @@ -103,9 +94,8 @@ namespace hpx::execution::experimental {
hpx::util::counting_shape(decltype(shape){0}, shape);

constexpr bool is_chunked =
!hpx::execution::experimental::stdexec_internal::
sender_expr_for<Sender,
hpx::execution::experimental::bulk_unchunked_t>;
!hpx::execution::experimental::stdexec_internal::__sender_for<
Sender, hpx::execution::experimental::bulk_unchunked_t>;

return hpx::execution::experimental::detail::
thread_pool_bulk_sender<Policy, std::decay_t<decltype(child)>,
Expand Down
8 changes: 4 additions & 4 deletions libs/core/executors/tests/unit/thread_pool_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2227,8 +2227,8 @@ void test_stdexec_domain_queries()
"bulk_chunked sender should satisfy "
"bulk_chunked_or_unchunked_sender concept");

auto transformed =
domain.transform_sender(std::move(chunked_sndr), env);
auto transformed = domain.transform_sender(
ex::set_value_t{}, std::move(chunked_sndr), env);

static_assert(is_thread_pool_bulk_sender<
std::decay_t<decltype(transformed)>>::value,
Expand All @@ -2250,8 +2250,8 @@ void test_stdexec_domain_queries()
"bulk_unchunked sender should satisfy "
"bulk_chunked_or_unchunked_sender concept");

auto transformed =
domain.transform_sender(std::move(unchunked_sndr), env);
auto transformed = domain.transform_sender(
ex::set_value_t{}, std::move(unchunked_sndr), env);

static_assert(is_thread_pool_bulk_sender<
std::decay_t<decltype(transformed)>>::value,
Expand Down
Loading