-
Notifications
You must be signed in to change notification settings - Fork 32
[Draft] showcase potential simplification of matrix op using concepts #1823
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -25,6 +25,27 @@ | |||||||||||||||||
| #include <experimental/mdspan> | ||||||||||||||||||
| #include <array> | ||||||||||||||||||
|
|
||||||||||||||||||
| template<class T> | ||||||||||||||||||
| struct is_mdspan : std::false_type {}; | ||||||||||||||||||
|
|
||||||||||||||||||
| template<class T, class Extents, class Layout, class Accessor> | ||||||||||||||||||
| struct is_mdspan<std::mdspan<T, Extents, Layout, Accessor>> : std::true_type {}; | ||||||||||||||||||
|
|
||||||||||||||||||
| template<class T> | ||||||||||||||||||
| inline constexpr bool is_mdspan_v = is_mdspan<std::remove_cvref_t<T>>::value; | ||||||||||||||||||
|
|
||||||||||||||||||
| template<class T> | ||||||||||||||||||
| concept Mdspan = is_mdspan_v<T>; | ||||||||||||||||||
|
|
||||||||||||||||||
| template<class T, std::size_t Rank> | ||||||||||||||||||
| concept MdspanRank = Mdspan<T> && (std::remove_cvref_t<T>::rank() == Rank); | ||||||||||||||||||
|
|
||||||||||||||||||
| template<class T> | ||||||||||||||||||
| using mdspan_value_t = typename std::remove_cvref_t<T>::value_type; | ||||||||||||||||||
|
|
||||||||||||||||||
| template<class A, class B> | ||||||||||||||||||
| concept mdspan_same_value = std::same_as<mdspan_value_t<A>, mdspan_value_t<B>>; | ||||||||||||||||||
|
|
||||||||||||||||||
| namespace shammath { | ||||||||||||||||||
|
|
||||||||||||||||||
| /** | ||||||||||||||||||
|
|
@@ -40,11 +61,12 @@ namespace shammath { | |||||||||||||||||
| * the value returned by the function is used to set the corresponding | ||||||||||||||||||
| * element of the matrix. | ||||||||||||||||||
| */ | ||||||||||||||||||
| template<class T, class Extents, class Layout, class Accessor, class Func> | ||||||||||||||||||
| inline void mat_set_vals(const std::mdspan<T, Extents, Layout, Accessor> &input, Func &&func) { | ||||||||||||||||||
|
|
||||||||||||||||||
| shambase::check_functor_signature<T, int, int>(func); | ||||||||||||||||||
|
|
||||||||||||||||||
| template<class Func> | ||||||||||||||||||
| inline void mat_set_vals(MdspanRank<2> auto &input, Func &&func) | ||||||||||||||||||
| requires requires(Func f, int a, int b) { | ||||||||||||||||||
| { f(a, b) } -> std::same_as<mdspan_value_t<decltype(input)>>; | ||||||||||||||||||
| } | ||||||||||||||||||
|
Comment on lines
+65
to
+68
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mdspan is a lightweight view and should be passed by value. Additionally, the requires clause uses std::same_as, which is overly restrictive as it prevents using functors that return types implicitly convertible to the matrix's value type (e.g., returning an int for a double matrix). Using std::convertible_to is more idiomatic and flexible.
Suggested change
|
||||||||||||||||||
| { | ||||||||||||||||||
| for (int i = 0; i < input.extent(0); i++) { | ||||||||||||||||||
| for (int j = 0; j < input.extent(1); j++) { | ||||||||||||||||||
| input(i, j) = func(i, j); | ||||||||||||||||||
|
|
@@ -89,13 +111,13 @@ namespace shammath { | |||||||||||||||||
| * diagonal (from the top-left to the bottom-right) set to 1, and all other | ||||||||||||||||||
| * elements set to 0. | ||||||||||||||||||
| */ | ||||||||||||||||||
| template<class T, class Extents, class Layout, class Accessor> | ||||||||||||||||||
| inline void mat_set_identity(const std::mdspan<T, Extents, Layout, Accessor> &input1) { | ||||||||||||||||||
| inline void mat_set_identity(MdspanRank<2> auto input) { | ||||||||||||||||||
| SHAM_ASSERT(input.extent(0) == input.extent(1)); | ||||||||||||||||||
|
|
||||||||||||||||||
| SHAM_ASSERT(input1.extent(0) == input1.extent(1)); | ||||||||||||||||||
| using T = mdspan_value_t<decltype(input)>; | ||||||||||||||||||
|
|
||||||||||||||||||
| mat_set_vals(input1, [](auto i, auto j) -> T { | ||||||||||||||||||
| return (i == j) ? 1 : 0; | ||||||||||||||||||
| mat_set_vals(input, [](int i, int j) -> T { | ||||||||||||||||||
| return i == j ? T{1} : T{0}; | ||||||||||||||||||
| }); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -222,21 +244,10 @@ namespace shammath { | |||||||||||||||||
| * from the first matrix and stores the result in the output matrix. The dimensions | ||||||||||||||||||
| * of both input matrices and the output matrix must be the same. | ||||||||||||||||||
| */ | ||||||||||||||||||
| template< | ||||||||||||||||||
| class T, | ||||||||||||||||||
| class Extents1, | ||||||||||||||||||
| class Extents2, | ||||||||||||||||||
| class Extents3, | ||||||||||||||||||
| class Layout1, | ||||||||||||||||||
| class Layout2, | ||||||||||||||||||
| class Layout3, | ||||||||||||||||||
| class Accessor1, | ||||||||||||||||||
| class Accessor2, | ||||||||||||||||||
| class Accessor3> | ||||||||||||||||||
| inline void mat_sub( | ||||||||||||||||||
| const std::mdspan<T, Extents1, Layout1, Accessor1> &input1, | ||||||||||||||||||
| const std::mdspan<T, Extents2, Layout2, Accessor2> &input2, | ||||||||||||||||||
| const std::mdspan<T, Extents3, Layout3, Accessor3> &output) { | ||||||||||||||||||
| MdspanRank<2> auto const &input1, | ||||||||||||||||||
| MdspanRank<2> auto const &input2, | ||||||||||||||||||
| MdspanRank<2> auto &output) { | ||||||||||||||||||
|
tdavidcl marked this conversation as resolved.
|
||||||||||||||||||
|
|
||||||||||||||||||
| SHAM_ASSERT(input1.extent(0) == output.extent(0)); | ||||||||||||||||||
| SHAM_ASSERT(input1.extent(1) == output.extent(1)); | ||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The traits and concepts introduced here are defined in the global namespace. They should be moved inside the
shammathnamespace to prevent name collisions and maintain a clean global scope, consistent with the rest of the library.