Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/shammath/include/shammath/matrix_exponential.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ namespace shammath {
mat_set_nul<T>(F);

for (auto k = r - 1; k >= 0; k--) {
mat_set_identity<T>(I);
mat_set_identity<T>(Id);
mat_set_identity(I);
mat_set_identity(Id);
mat_set_nul<T>(B);
i32 cc = 0;

Expand All @@ -295,7 +295,7 @@ namespace shammath {
cc = q * k + j;
mat_axpy_beta<T, U>(bi_seq[cc], I, 1, B);
}
mat_set_identity<T>(Id);
mat_set_identity(Id);

i32 cond = (k >= 1);
mat_axpy_beta<T, U>(1, B, 1, F);
Expand Down Expand Up @@ -361,8 +361,8 @@ namespace shammath {
taylor_eval<T, U>(r, q, seq_bi, size_A, A, F, B, I, Id);

// squaring step
mat_set_identity<T>(Id);
mat_set_identity<T>(I);
mat_set_identity(Id);
mat_set_identity(I);

for (auto j = 1; j <= pw; j++) {
mat_copy<T>(I, Id);
Expand Down
59 changes: 35 additions & 24 deletions src/shammath/include/shammath/matrix_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,27 @@
#include <experimental/mdspan>
#include <array>

template<class T>
struct is_mdspan : std::false_type {};

template<class T, class Extents, class Layout, class Accessor>
struct is_mdspan<std::mdspan<T, Extents, Layout, Accessor>> : std::true_type {};

template<class T>
inline constexpr bool is_mdspan_v = is_mdspan<std::remove_cvref_t<T>>::value;

template<class T>
concept Mdspan = is_mdspan_v<T>;

template<class T, std::size_t Rank>
concept MdspanRank = Mdspan<T> && (std::remove_cvref_t<T>::rank() == Rank);

template<class T>
using mdspan_value_t = typename std::remove_cvref_t<T>::value_type;

template<class A, class B>
concept mdspan_same_value = std::same_as<mdspan_value_t<A>, mdspan_value_t<B>>;
Comment on lines +28 to +47

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The traits and concepts introduced here are defined in the global namespace. They should be moved inside the shammath namespace to prevent name collisions and maintain a clean global scope, consistent with the rest of the library.


namespace shammath {

/**
Expand All @@ -40,11 +61,12 @@ namespace shammath {
* the value returned by the function is used to set the corresponding
* element of the matrix.
*/
template<class T, class Extents, class Layout, class Accessor, class Func>
inline void mat_set_vals(const std::mdspan<T, Extents, Layout, Accessor> &input, Func &&func) {

shambase::check_functor_signature<T, int, int>(func);

template<class Func>
inline void mat_set_vals(MdspanRank<2> auto &input, Func &&func)
requires requires(Func f, int a, int b) {
{ f(a, b) } -> std::same_as<mdspan_value_t<decltype(input)>>;
}
Comment on lines +65 to +68

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

mdspan is a lightweight view and should be passed by value. Additionally, the requires clause uses std::same_as, which is overly restrictive as it prevents using functors that return types implicitly convertible to the matrix's value type (e.g., returning an int for a double matrix). Using std::convertible_to is more idiomatic and flexible.

Suggested change
inline void mat_set_vals(MdspanRank<2> auto &input, Func &&func)
requires requires(Func f, int a, int b) {
{ f(a, b) } -> std::same_as<mdspan_value_t<decltype(input)>>;
}
inline void mat_set_vals(MdspanRank<2> auto input, Func &&func)
requires requires(Func f, int a, int b) {
{ f(a, b) } -> std::convertible_to<mdspan_value_t<decltype(input)>>;
}

{
for (int i = 0; i < input.extent(0); i++) {
for (int j = 0; j < input.extent(1); j++) {
input(i, j) = func(i, j);
Expand Down Expand Up @@ -89,13 +111,13 @@ namespace shammath {
* diagonal (from the top-left to the bottom-right) set to 1, and all other
* elements set to 0.
*/
template<class T, class Extents, class Layout, class Accessor>
inline void mat_set_identity(const std::mdspan<T, Extents, Layout, Accessor> &input1) {
inline void mat_set_identity(MdspanRank<2> auto input) {
SHAM_ASSERT(input.extent(0) == input.extent(1));

SHAM_ASSERT(input1.extent(0) == input1.extent(1));
using T = mdspan_value_t<decltype(input)>;

mat_set_vals(input1, [](auto i, auto j) -> T {
return (i == j) ? 1 : 0;
mat_set_vals(input, [](int i, int j) -> T {
return i == j ? T{1} : T{0};
});
}

Expand Down Expand Up @@ -222,21 +244,10 @@ namespace shammath {
* from the first matrix and stores the result in the output matrix. The dimensions
* of both input matrices and the output matrix must be the same.
*/
template<
class T,
class Extents1,
class Extents2,
class Extents3,
class Layout1,
class Layout2,
class Layout3,
class Accessor1,
class Accessor2,
class Accessor3>
inline void mat_sub(
const std::mdspan<T, Extents1, Layout1, Accessor1> &input1,
const std::mdspan<T, Extents2, Layout2, Accessor2> &input2,
const std::mdspan<T, Extents3, Layout3, Accessor3> &output) {
MdspanRank<2> auto const &input1,
MdspanRank<2> auto const &input2,
MdspanRank<2> auto &output) {
Comment thread
tdavidcl marked this conversation as resolved.

SHAM_ASSERT(input1.extent(0) == output.extent(0));
SHAM_ASSERT(input1.extent(1) == output.extent(1));
Expand Down
4 changes: 2 additions & 2 deletions src/shammodels/ramses/src/modules/DragIntegrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ void shammodels::basegodunov::modules::DragIntegrator<Tvec, TgridVec>::enable_ex
get_jacobian(id_a, mdspan_A);

// pre-processing step
shammath::mat_set_identity<Tscal>(mdspan_Id);
shammath::mat_set_identity(mdspan_Id);
shammath::mat_axpy_beta<Tscal, Tscal>(-mu, mdspan_Id, dt, mdspan_A);

// compute matrix exponential
Expand Down Expand Up @@ -670,7 +670,7 @@ void shammodels::basegodunov::modules::DragIntegrator<Tvec, TgridVec>::enable_ex
get_jacobian(id_a, mdspan_A);

// pre-processing step
shammath::mat_set_identity<Tscal>(mdspan_Id);
shammath::mat_set_identity(mdspan_Id);
shammath::mat_axpy_beta<Tscal, Tscal>(-mu, mdspan_Id, dt, mdspan_A);

// compute matrix exponential
Expand Down
Loading