Skip to content

Commit 83b4eb9

Browse files
committed
parallel scheduler uses cached mask
1 parent 55b419e commit 83b4eb9

5 files changed

Lines changed: 268 additions & 27 deletions

File tree

libs/core/executors/include/hpx/executors/parallel_scheduler.hpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ namespace hpx::execution::experimental {
5959
// Get the parallel_scheduler from the child sender's
6060
// completion scheduler (completes_on pattern)
6161
auto par_sched = [&]() {
62-
if constexpr (hpx::is_invocable_v<
63-
hpx::execution::experimental::
64-
get_completion_scheduler_t<hpx::
65-
execution::experimental::
66-
set_value_t>,
67-
decltype(hpx::execution::experimental::
68-
get_env(child))>)
62+
if constexpr (
63+
hpx::is_invocable_v<
64+
hpx::execution::experimental::
65+
get_completion_scheduler_t<
66+
hpx::execution::experimental::set_value_t>,
67+
decltype(hpx::execution::experimental::get_env(
68+
child))>)
6969
{
7070
return hpx::execution::experimental::
7171
get_completion_scheduler<
@@ -93,15 +93,18 @@ namespace hpx::execution::experimental {
9393
constexpr bool is_parallel =
9494
!is_sequenced_policy_v<std::decay_t<decltype(pol.__get())>>;
9595

96+
constexpr bool is_unsequenced = is_unsequenced_bulk_policy_v<
97+
std::decay_t<decltype(pol.__get())>>;
98+
9699
// Pass the pre-cached PU mask so thread_pool_bulk_sender
97100
// skips its own full_mask() computation on every invocation.
98101
hpx::threads::mask_type pu_mask = par_sched.get_pu_mask();
99102
return hpx::execution::experimental::detail::
100103
thread_pool_bulk_sender<hpx::launch,
101104
std::decay_t<decltype(child)>,
102105
std::decay_t<decltype(iota_shape)>,
103-
std::decay_t<decltype(f)>, is_chunked, is_parallel>(
104-
HPX_MOVE(underlying),
106+
std::decay_t<decltype(f)>, is_chunked, is_parallel,
107+
is_unsequenced>(HPX_MOVE(underlying),
105108
HPX_FORWARD(decltype(child), child),
106109
HPX_MOVE(iota_shape), HPX_FORWARD(decltype(f), f),
107110
HPX_MOVE(pu_mask));

libs/core/executors/include/hpx/executors/scheduler_executor.hpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#if defined(HPX_HAVE_STDEXEC)
2222
#include <hpx/executors/detail/index_queue_spawning.hpp>
23+
#include <hpx/executors/parallel_scheduler.hpp>
2324
#endif
2425

2526
#include <cstddef>
@@ -47,10 +48,48 @@ namespace hpx::execution::experimental {
4748
{
4849
};
4950

51+
// parallel_scheduler wraps thread_pool_policy_scheduler; use the same
52+
// index_queue fast path with thread_pool_params<parallel_scheduler>
53+
// so pu_mask() can return the cached mask from get_pu_mask().
54+
template <>
55+
struct has_thread_pool_backend<parallel_scheduler> : std::true_type
56+
{
57+
};
58+
5059
// Helper to extract thread pool parameters from a scheduler
5160
template <typename Scheduler>
5261
struct thread_pool_params; // primary: not defined
5362

63+
template <>
64+
struct thread_pool_params<parallel_scheduler>
65+
{
66+
static auto* pool(parallel_scheduler const& sched)
67+
{
68+
return sched.get_underlying_scheduler().get_thread_pool();
69+
}
70+
static std::size_t first_core(parallel_scheduler const& sched)
71+
{
72+
return hpx::execution::experimental::get_first_core(
73+
sched.get_underlying_scheduler());
74+
}
75+
static std::size_t num_cores(parallel_scheduler const& sched)
76+
{
77+
return hpx::execution::experimental::processing_units_count(
78+
hpx::execution::experimental::null_parameters,
79+
sched.get_underlying_scheduler(),
80+
hpx::chrono::null_duration, 0);
81+
}
82+
static auto const& policy(parallel_scheduler const& sched)
83+
{
84+
return sched.get_underlying_scheduler().policy();
85+
}
86+
static hpx::threads::mask_type pu_mask(
87+
parallel_scheduler const& sched)
88+
{
89+
return sched.get_pu_mask();
90+
}
91+
};
92+
5493
template <typename Policy>
5594
struct thread_pool_params<thread_pool_policy_scheduler<Policy>>
5695
{

libs/core/executors/include/hpx/executors/thread_pool_scheduler.hpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
// Forward declaration
3030
namespace hpx::execution::experimental::detail {
3131
template <typename Policy, typename Sender, typename Shape, typename F,
32-
bool IsChunked, bool IsParallel>
32+
bool IsChunked = false, bool IsParallel = true,
33+
bool IsUnsequenced = false>
3334
class thread_pool_bulk_sender;
3435
}
3536
#endif
@@ -85,6 +86,19 @@ namespace hpx::execution::experimental {
8586
inline constexpr bool is_sequenced_policy_v<stdexec::unsequenced_policy> =
8687
true;
8788

89+
// True for unseq and par_unseq
90+
template <typename Policy>
91+
inline constexpr bool is_unsequenced_bulk_policy_v = false;
92+
93+
template <>
94+
inline constexpr bool
95+
is_unsequenced_bulk_policy_v<stdexec::unsequenced_policy> = true;
96+
97+
template <>
98+
inline constexpr bool
99+
is_unsequenced_bulk_policy_v<stdexec::parallel_unsequenced_policy> =
100+
true;
101+
88102
// Domain customization for stdexec bulk operations
89103
// Only the env-based transform_sender is provided. The early (no-env)
90104
// transform falls through to default_domain, and the late transform
@@ -129,12 +143,23 @@ namespace hpx::execution::experimental {
129143
constexpr bool is_parallel =
130144
!is_sequenced_policy_v<std::decay_t<decltype(pol.__get())>>;
131145

146+
constexpr bool is_unsequenced = is_unsequenced_bulk_policy_v<
147+
std::decay_t<decltype(pol.__get())>>;
148+
149+
// Pre-compute the PU mask once and pass it to the 5-arg
150+
// constructor to avoid the expensive full_mask() call (O(N^2))
151+
// that the 4-arg constructor would trigger on every bulk
152+
// operation.
153+
auto pu_mask =
154+
hpx::execution::experimental::get_processing_units_mask(sched);
155+
132156
return hpx::execution::experimental::detail::
133157
thread_pool_bulk_sender<Policy, std::decay_t<decltype(child)>,
134158
std::decay_t<decltype(iota_shape)>,
135-
std::decay_t<decltype(f)>, is_chunked, is_parallel>{
136-
HPX_MOVE(sched), HPX_FORWARD(decltype(child), child),
137-
HPX_MOVE(iota_shape), HPX_FORWARD(decltype(f), f)};
159+
std::decay_t<decltype(f)>, is_chunked, is_parallel,
160+
is_unsequenced>{HPX_MOVE(sched),
161+
HPX_FORWARD(decltype(child), child), HPX_MOVE(iota_shape),
162+
HPX_FORWARD(decltype(f), f), HPX_MOVE(pu_mask)};
138163
}
139164
};
140165

libs/core/executors/include/hpx/executors/thread_pool_scheduler_bulk.hpp

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,24 @@ namespace hpx::execution::experimental::detail {
9292
(n + static_cast<std::size_t>(num_threads) - 1) / num_threads);
9393
}
9494

95+
/// Round a chunk up to a multiple of 16 when it is
96+
/// smaller than size; the rounded value is capped at size
97+
HPX_CXX_CORE_EXPORT constexpr std::uint32_t align_chunk_for_vectorization(
98+
std::uint32_t chunk, std::uint32_t const size) noexcept
99+
{
100+
constexpr std::uint32_t g = 16;
101+
if (chunk == 0 || chunk >= size)
102+
return chunk;
103+
std::uint64_t c = chunk;
104+
if (c % g != 0)
105+
{
106+
c = ((c + g - 1) / g) * g;
107+
}
108+
if (c > size)
109+
c = size;
110+
return static_cast<std::uint32_t>(c);
111+
}
112+
95113
// For bulk_unchunked: f(index, ...)
96114
HPX_CXX_CORE_EXPORT template <std::size_t... Is, typename F, typename T,
97115
typename Ts>
@@ -183,9 +201,8 @@ namespace hpx::execution::experimental::detail {
183201

184202
auto const i_begin =
185203
static_cast<std::size_t>(index) * op_state->chunk_size;
186-
auto const i_end =
187-
(std::min) (i_begin + op_state->chunk_size,
188-
static_cast<std::size_t>(op_state->size));
204+
auto const i_end = (std::min) (i_begin + op_state->chunk_size,
205+
static_cast<std::size_t>(op_state->size));
189206

190207
if constexpr (OperationState::is_chunked)
191208
{
@@ -195,14 +212,14 @@ namespace hpx::execution::experimental::detail {
195212
}
196213
else
197214
{
198-
// bulk_unchunked: f(index, values...) for each element
199-
// In unchunked case, chunk_size is 1
200-
// so each chunk will only have one element.
201-
// The regular bulk invocation will go through the is_chunked case.
215+
// bulk_unchunked: one element call f(shape_index, values...) per i.
202216
auto it = std::ranges::next(
203217
hpx::util::begin(op_state->shape), i_begin);
204-
bulk_scheduler_invoke_helper(
205-
index_pack_type{}, op_state->f, *it, ts);
218+
for (auto i = i_begin; i < i_end; ++i, ++it)
219+
{
220+
bulk_scheduler_invoke_helper(
221+
index_pack_type{}, op_state->f, *it, ts);
222+
}
206223
}
207224
}
208225

@@ -319,7 +336,8 @@ namespace hpx::execution::experimental::detail {
319336
// Otherwise, it will call set_value on the connected receiver.
320337
void finish() const
321338
{
322-
if (--(op_state->tasks_remaining.data_) == 0)
339+
if (op_state->tasks_remaining.data_.fetch_sub(
340+
1, std::memory_order_acq_rel) == 1)
323341
{
324342
if (op_state->bad_alloc_thrown.load(std::memory_order_relaxed))
325343
{
@@ -557,8 +575,16 @@ namespace hpx::execution::experimental::detail {
557575
}
558576
else
559577
{
560-
chunk_size = 1;
561-
num_chunks = size;
578+
chunk_size = get_bulk_scheduler_chunk_size(
579+
op_state->num_worker_threads, size);
580+
num_chunks = (size + chunk_size - 1) / chunk_size;
581+
}
582+
583+
if constexpr (OperationState::is_unsequenced &&
584+
OperationState::is_parallel)
585+
{
586+
chunk_size = align_chunk_for_vectorization(chunk_size, size);
587+
num_chunks = (size + chunk_size - 1) / chunk_size;
562588
}
563589

564590
// launch only as many tasks as we have chunks
@@ -723,6 +749,16 @@ namespace hpx::execution::experimental::detail {
723749
#endif
724750
};
725751

752+
#if !defined(HPX_HAVE_STDEXEC)
753+
// With stdexec, thread_pool_scheduler.hpp forward declares this template
754+
// with default arguments; without it, declare here so the definition below
755+
// does not repeat default template arguments.
756+
template <typename Policy, typename Sender, typename Shape, typename F,
757+
bool IsChunked = false, bool IsParallel = true,
758+
bool IsUnsequenced = false>
759+
class thread_pool_bulk_sender;
760+
#endif
761+
726762
// This sender represents bulk work that will be performed using the
727763
// thread_pool_scheduler.
728764
//
@@ -740,8 +776,8 @@ namespace hpx::execution::experimental::detail {
740776
// threads.
741777
//
742778
HPX_CXX_CORE_EXPORT template <typename Policy, typename Sender,
743-
typename Shape, typename F, bool IsChunked = false,
744-
bool IsParallel = true>
779+
typename Shape, typename F, bool IsChunked, bool IsParallel,
780+
bool IsUnsequenced>
745781
class thread_pool_bulk_sender
746782
{
747783
private:
@@ -885,6 +921,7 @@ namespace hpx::execution::experimental::detail {
885921
{
886922
static constexpr bool is_chunked = IsChunked;
887923
static constexpr bool is_parallel = IsParallel;
924+
static constexpr bool is_unsequenced = IsUnsequenced;
888925

889926
using operation_state_type =
890927
hpx::execution::experimental::connect_result_t<Sender,
@@ -899,9 +936,11 @@ namespace hpx::execution::experimental::detail {
899936
bool reverse_placement = false;
900937
bool allow_stealing = false;
901938
hpx::threads::mask_type pu_mask;
939+
902940
std::vector<hpx::util::cache_aligned_data<
903941
hpx::concurrency::detail::non_contiguous_index_queue<>>>
904942
queues;
943+
905944
HPX_NO_UNIQUE_ADDRESS std::decay_t<Shape> shape;
906945
HPX_NO_UNIQUE_ADDRESS std::decay_t<F> f;
907946
HPX_NO_UNIQUE_ADDRESS std::decay_t<Receiver> receiver;

0 commit comments

Comments
 (0)