diff --git a/src/shammath/include/shammath/matrix_exponential.hpp b/src/shammath/include/shammath/matrix_exponential.hpp index 2f795387c..1d9fed137 100644 --- a/src/shammath/include/shammath/matrix_exponential.hpp +++ b/src/shammath/include/shammath/matrix_exponential.hpp @@ -284,8 +284,8 @@ namespace shammath { mat_set_nul(F); for (auto k = r - 1; k >= 0; k--) { - mat_set_identity(I); - mat_set_identity(Id); + mat_set_identity(I); + mat_set_identity(Id); mat_set_nul(B); i32 cc = 0; @@ -295,7 +295,7 @@ namespace shammath { cc = q * k + j; mat_axpy_beta(bi_seq[cc], I, 1, B); } - mat_set_identity(Id); + mat_set_identity(Id); i32 cond = (k >= 1); mat_axpy_beta(1, B, 1, F); @@ -361,8 +361,8 @@ namespace shammath { taylor_eval(r, q, seq_bi, size_A, A, F, B, I, Id); // squaring step - mat_set_identity(Id); - mat_set_identity(I); + mat_set_identity(Id); + mat_set_identity(I); for (auto j = 1; j <= pw; j++) { mat_copy(I, Id); diff --git a/src/shammath/include/shammath/matrix_op.hpp b/src/shammath/include/shammath/matrix_op.hpp index dde113f36..6ed53f2eb 100644 --- a/src/shammath/include/shammath/matrix_op.hpp +++ b/src/shammath/include/shammath/matrix_op.hpp @@ -25,6 +25,27 @@ #include #include +template +struct is_mdspan : std::false_type {}; + +template +struct is_mdspan> : std::true_type {}; + +template +inline constexpr bool is_mdspan_v = is_mdspan>::value; + +template +concept Mdspan = is_mdspan_v; + +template +concept MdspanRank = Mdspan && (std::remove_cvref_t::rank() == Rank); + +template +using mdspan_value_t = typename std::remove_cvref_t::value_type; + +template +concept mdspan_same_value = std::same_as, mdspan_value_t>; + namespace shammath { /** @@ -40,11 +61,12 @@ namespace shammath { * the value returned by the function is used to set the corresponding * element of the matrix. */ - template - inline void mat_set_vals(const std::mdspan &input, Func &&func) { - - shambase::check_functor_signature(func); - + template + inline void mat_set_vals(MdspanRank<2> auto &input, Func &&func) + requires requires(Func f, int a, int b) { + { f(a, b) } -> std::same_as>; + } + { for (int i = 0; i < input.extent(0); i++) { for (int j = 0; j < input.extent(1); j++) { input(i, j) = func(i, j); @@ -89,13 +111,13 @@ namespace shammath { * diagonal (from the top-left to the bottom-right) set to 1, and all other * elements set to 0. */ - template - inline void mat_set_identity(const std::mdspan &input1) { + inline void mat_set_identity(MdspanRank<2> auto input) { + SHAM_ASSERT(input.extent(0) == input.extent(1)); - SHAM_ASSERT(input1.extent(0) == input1.extent(1)); + using T = mdspan_value_t; - mat_set_vals(input1, [](auto i, auto j) -> T { - return (i == j) ? 1 : 0; + mat_set_vals(input, [](int i, int j) -> T { + return i == j ? T{1} : T{0}; }); } @@ -222,21 +244,10 @@ namespace shammath { * from the first matrix and stores the result in the output matrix. The dimensions * of both input matrices and the output matrix must be the same. */ - template< - class T, - class Extents1, - class Extents2, - class Extents3, - class Layout1, - class Layout2, - class Layout3, - class Accessor1, - class Accessor2, - class Accessor3> inline void mat_sub( - const std::mdspan &input1, - const std::mdspan &input2, - const std::mdspan &output) { + MdspanRank<2> auto const &input1, + MdspanRank<2> auto const &input2, + MdspanRank<2> auto &output) { SHAM_ASSERT(input1.extent(0) == output.extent(0)); SHAM_ASSERT(input1.extent(1) == output.extent(1)); diff --git a/src/shammodels/ramses/src/modules/DragIntegrator.cpp b/src/shammodels/ramses/src/modules/DragIntegrator.cpp index fca85eaba..51c1085e3 100644 --- a/src/shammodels/ramses/src/modules/DragIntegrator.cpp +++ b/src/shammodels/ramses/src/modules/DragIntegrator.cpp @@ -481,7 +481,7 @@ void shammodels::basegodunov::modules::DragIntegrator::enable_ex get_jacobian(id_a, mdspan_A); // pre-processing step - shammath::mat_set_identity(mdspan_Id); + shammath::mat_set_identity(mdspan_Id); shammath::mat_axpy_beta(-mu, mdspan_Id, dt, mdspan_A); // compute matrix exponential @@ -670,7 +670,7 @@ void shammodels::basegodunov::modules::DragIntegrator::enable_ex get_jacobian(id_a, mdspan_A); // pre-processing step - shammath::mat_set_identity(mdspan_Id); + shammath::mat_set_identity(mdspan_Id); shammath::mat_axpy_beta(-mu, mdspan_Id, dt, mdspan_A); // compute matrix exponential