diff --git a/examples/sph/run_sph_rendering.py b/examples/sph/run_sph_rendering.py index ee4439568..78235f472 100644 --- a/examples/sph/run_sph_rendering.py +++ b/examples/sph/run_sph_rendering.py @@ -211,7 +211,16 @@ def H_profile(r): setup.apply_setup(gen_disc) model.do_vtk_dump("init_disc.vtk", True) +model.dump("tmp.sham") + +ctx = shamrock.Context() +model = shamrock.get_Model_SPH(context=ctx, vector_type="f64_3", sph_kernel="M4") +model.load_from_dump("tmp.sham") + +model.dump("tmp2.sham") + +exit() model.change_htolerances(coarse=1.3, fine=1.1) model.timestep() model.change_htolerances(coarse=1.1, fine=1.1) diff --git a/src/shammodels/sph/src/Model.cpp b/src/shammodels/sph/src/Model.cpp index d05168906..204599769 100644 --- a/src/shammodels/sph/src/Model.cpp +++ b/src/shammodels/sph/src/Model.cpp @@ -32,6 +32,7 @@ #include "shamrock/patch/PatchDataLayer.hpp" #include "shamrock/scheduler/DataInserterUtility.hpp" #include "shamrock/scheduler/PatchScheduler.hpp" +#include "shamrock/solvergraph/ScalarEdgeSerializable.hpp" #include "shamsys/NodeInstance.hpp" #include "shamsys/legacy/log.hpp" #include @@ -69,6 +70,10 @@ void shammodels::sph::Model::init() { PatchScheduler &sched = shambase::get_check_ref(ctx.sched); + auto time_edge = sched.synchronized_data.container.register_edge( + "time", shamrock::solvergraph::ScalarEdgeSerializable("time", "t")); + time_edge->value = 0; + sched.add_root_patch(); shamlog_debug_ln("Sys", "build local scheduler tables"); diff --git a/src/shamrock/include/shamrock/scheduler/PatchScheduler.hpp b/src/shamrock/include/shamrock/scheduler/PatchScheduler.hpp index 602504e69..034d210eb 100644 --- a/src/shamrock/include/shamrock/scheduler/PatchScheduler.hpp +++ b/src/shamrock/include/shamrock/scheduler/PatchScheduler.hpp @@ -23,15 +23,19 @@ #include "shambase/DistributedData.hpp" #include "shambase/stacktrace.hpp" #include "shambase/time.hpp" +#include "nlohmann/json_fwd.hpp" #include "shamalgs/collective/distributedDataComm.hpp" #include "shamrock/legacy/patch/utility/patch_field.hpp" #include "shamrock/solvergraph/NodeSetEdge.hpp" #include "shamrock/solvergraph/PatchDataLayerRefs.hpp" +#include "shamrock/solvergraph/SolverGraph.hpp" #include +#include #include #include #include #include +#include #include #include #include @@ -47,8 +51,25 @@ #include "shamrock/scheduler/HilbertLoadBalance.hpp" #include "shamrock/scheduler/PatchTree.hpp" #include "shamrock/scheduler/SchedulerPatchData.hpp" +#include "shamrock/solvergraph/IEdgeNamed.hpp" +#include "shamrock/solvergraph/JsonSerializable.hpp" #include "shamsys/legacy/sycl_handler.hpp" +inline std::unordered_map< + std::string, + std::function(const nlohmann::json &j)>> + deser_map = {}; + +/// Data stored within the scheduler that are garanteed to be in sink across all ranks +struct SynchronizedData { + shamrock::solvergraph::SolverGraph container + = shamrock::solvergraph::SolverGraph::with_constraint( + std::nullopt, shamrock::solvergraph::json_serializable_edge_constraint); + + nlohmann::json to_json(); + + void from_json(const nlohmann::json &j); +}; struct PatchSchedulerConfig { u64 split_load_value = 0_u64; u64 merge_load_value = 0_u64; @@ -98,9 +119,10 @@ class PatchScheduler { u64 crit_patch_split; ///< splitting limit (if load value > crit_patch_split => patch split) u64 crit_patch_merge; ///< merging limit (if load value < crit_patch_merge => patch merge) - SchedulerPatchList patch_list; ///< handle the list of the patches of the scheduler - SchedulerPatchData patch_data; ///< handle the data of the patches of the scheduler - PatchTree patch_tree; ///< handle the tree structure of the patches + SchedulerPatchList patch_list; ///< handle the list of the patches of the scheduler + SchedulerPatchData patch_data; ///< handle the data of the patches of the scheduler + PatchTree patch_tree; ///< handle the tree structure of the patches + SynchronizedData synchronized_data; ///< data that is synchroneous across all ranks // using unordered set is not an issue since we use the find command after std::unordered_set owned_patch_id; ///< list of owned patch ids updated with diff --git a/src/shamrock/include/shamrock/solvergraph/IEdgeNamed.hpp b/src/shamrock/include/shamrock/solvergraph/IEdgeNamed.hpp index 17ac3a062..d0a0c1269 100644 --- a/src/shamrock/include/shamrock/solvergraph/IEdgeNamed.hpp +++ b/src/shamrock/include/shamrock/solvergraph/IEdgeNamed.hpp @@ -29,6 +29,7 @@ namespace shamrock::solvergraph { virtual std::string _impl_get_dot_label() const { return name; } virtual std::string _impl_get_tex_symbol() const { return "{" + texsymbol + "}"; } + virtual std::string get_raw_tex_symbol() const { return texsymbol; } }; } // namespace shamrock::solvergraph diff --git a/src/shamrock/include/shamrock/solvergraph/JsonSerializable.hpp b/src/shamrock/include/shamrock/solvergraph/JsonSerializable.hpp new file mode 100644 index 000000000..602a3924a --- /dev/null +++ b/src/shamrock/include/shamrock/solvergraph/JsonSerializable.hpp @@ -0,0 +1,38 @@ +// -------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2026 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +#pragma once + +/** + * @file JsonSerializable.hpp + * @author Timothée David--Cléris (tim.shamrock@proton.me) + * @brief + */ + +#include "shamrock/solvergraph/IEdge.hpp" +#include +#include + +namespace shamrock::solvergraph { + + struct JsonSerializable { + virtual ~JsonSerializable() {}; + + virtual void to_json(nlohmann::json &j) = 0; + virtual void from_json(const nlohmann::json &j) = 0; + + virtual std::string type_name() = 0; + }; + + inline bool json_serializable_edge_constraint( + const std::shared_ptr &edge) { + // check that the edge can be cross-casted to JsonSerializable + return bool(std::dynamic_pointer_cast(edge)); + }; +} // namespace shamrock::solvergraph diff --git a/src/shamrock/include/shamrock/solvergraph/ScalarEdgeSerializable.hpp b/src/shamrock/include/shamrock/solvergraph/ScalarEdgeSerializable.hpp new file mode 100644 index 000000000..57f7c487f --- /dev/null +++ b/src/shamrock/include/shamrock/solvergraph/ScalarEdgeSerializable.hpp @@ -0,0 +1,85 @@ +// -------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2026 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +#pragma once + +/** + * @file ScalarEdgeSerializable.hpp + * @author Timothée David--Cléris (tim.shamrock@proton.me) + * @brief + * + */ + +#include "shambase/exception.hpp" +#include "shambase/pre_main_call.hpp" +#include "shambase/string.hpp" +#include "shambase/type_name_info.hpp" +#include "nlohmann/json_fwd.hpp" +#include "shamrock/scheduler/PatchScheduler.hpp" +#include "shamrock/solvergraph/IEdge.hpp" +#include "shamrock/solvergraph/ScalarEdge.hpp" +#include +#include + +namespace shamrock::solvergraph { + + template + class ScalarEdgeSerializable : public ScalarEdge, public JsonSerializable { + public: + using ScalarEdge::ScalarEdge; + using ScalarEdge::value; + + virtual void to_json(nlohmann::json &j) { + j = nlohmann::json{ + {"type", type_name()}, + {"value", value}, + {"label", this->get_label()}, + {"tex_symbol", this->get_raw_tex_symbol()}}; + }; + + virtual void from_json(const nlohmann::json &j) { + std::string type = j.at("type"); + + if (type != type_name()) { + throw shambase::make_except_with_loc(shambase::format( + "error when deserializing ScalarEdgeSerializable, expected type info " + "\"{}\" but got \"{}\"", + type_name(), + type)); + } + + value = j.at("value").get(); + }; + + inline static std::string type_name_static() { + return "ScalarEdgeSerializable<" + shambase::get_type_name() + ">"; + } + + virtual std::string type_name() { return type_name_static(); }; + }; + +} // namespace shamrock::solvergraph + +template +void register_ctor_deser() { + + auto ctor = [](const nlohmann::json &j) -> std::shared_ptr { + std::string label = j.at("label").get(); + std::string tex_symbol = j.at("tex_symbol").get(); + + return std::make_shared>( + label, tex_symbol); + }; + + deser_map.insert({shamrock::solvergraph::ScalarEdgeSerializable::type_name_static(), ctor}); +} + +PRE_MAIN_FUNCTION_CALL([&]() { + register_ctor_deser(); +}) diff --git a/src/shamrock/include/shamrock/solvergraph/SolverGraph.hpp b/src/shamrock/include/shamrock/solvergraph/SolverGraph.hpp index 6f4006afa..cd744f06b 100644 --- a/src/shamrock/include/shamrock/solvergraph/SolverGraph.hpp +++ b/src/shamrock/include/shamrock/solvergraph/SolverGraph.hpp @@ -16,14 +16,40 @@ * */ +#include "shambase/exception.hpp" #include "shambase/memory.hpp" #include "shamrock/solvergraph/IEdge.hpp" #include "shamrock/solvergraph/INode.hpp" #include +#include #include +#include +#include +#include namespace shamrock::solvergraph { + struct SolverGraphContraint { + std::optional &)>> _validate_node; + std::optional &)>> _validate_edge; + + inline static SolverGraphContraint no_constraint() { return {std::nullopt, std::nullopt}; } + + inline bool validate_node(const std::shared_ptr &node) { + if (_validate_node) { + return (*_validate_node)(node); + } + return true; + } + + inline bool validate_edge(const std::shared_ptr &edge) { + if (_validate_edge) { + return (*_validate_edge)(edge); + } + return true; + } + }; + /** * @brief A graph container for managing solver nodes and edges with type-safe access. * @@ -55,16 +81,31 @@ namespace shamrock::solvergraph { */ class SolverGraph { /// Registry of nodes by name - std::unordered_map> nodes; + std::unordered_map> nodes = {}; /// Registry of edges by name - std::unordered_map> edges; + std::unordered_map> edges = {}; + + SolverGraphContraint constraint = SolverGraphContraint::no_constraint(); public: /////////////////////////////////////// // base getters and setters /////////////////////////////////////// + SolverGraph() = default; + + SolverGraph( + std::optional &)>> _validate_node, + std::optional &)>> _validate_edge) + : constraint(SolverGraphContraint{_validate_node, _validate_edge}) {} + + inline static SolverGraph with_constraint( + std::optional &)>> _validate_node, + std::optional &)>> _validate_edge) { + return SolverGraph{_validate_node, _validate_edge}; + } + /** * @brief Register a node with the graph using a shared pointer. * @@ -74,6 +115,12 @@ namespace shamrock::solvergraph { */ inline std::shared_ptr register_node_ptr_base( const std::string &name, std::shared_ptr node) { + + if (!constraint.validate_node(node)) { + throw shambase::make_except_with_loc( + "node validation failed under solvergraph constraint"); + } + const auto [it, inserted] = nodes.try_emplace(name, std::move(node)); if (!inserted) { shambase::throw_with_loc( @@ -91,6 +138,12 @@ namespace shamrock::solvergraph { */ inline std::shared_ptr register_edge_ptr_base( const std::string &name, std::shared_ptr edge) { + + if (!constraint.validate_edge(edge)) { + throw shambase::make_except_with_loc( + "edge validation failed under solvergraph constraint"); + } + const auto [it, inserted] = edges.try_emplace(name, std::move(edge)); if (!inserted) { shambase::throw_with_loc( @@ -336,6 +389,24 @@ namespace shamrock::solvergraph { inline const T &get_edge_ref(const std::string &name) const { return shambase::get_check_ref(get_edge_ptr(name)); } + + std::vector get_edge_names() { + std::vector ret{}; + + for (auto &[k, e] : edges) { + ret.push_back(k); + } + return ret; + } + + std::vector get_node_names() { + std::vector ret{}; + + for (auto &[k, n] : nodes) { + ret.push_back(k); + } + return ret; + } }; } // namespace shamrock::solvergraph diff --git a/src/shamrock/src/io/ShamrockDump.cpp b/src/shamrock/src/io/ShamrockDump.cpp index e3e7c97e8..e66ffa2fd 100644 --- a/src/shamrock/src/io/ShamrockDump.cpp +++ b/src/shamrock/src/io/ShamrockDump.cpp @@ -200,6 +200,9 @@ namespace shamrock { sched.patch_list = jmeta_patch.at("patchlist").get(); sched.patch_tree = jmeta_patch.at("patchtree").get(); sched.patch_data.sim_box.from_json(jmeta_patch.at("sim_box")); + if (jmeta_patch.contains("synchronized_data")) { + sched.synchronized_data.from_json(jmeta_patch.at("synchronized_data")); + } // edit patch owner to fit in new world size, or spread if more processes now // a bit dirty but gets the job done for now diff --git a/src/shamrock/src/scheduler/PatchScheduler.cpp b/src/shamrock/src/scheduler/PatchScheduler.cpp index 9faace65f..901cf24df 100644 --- a/src/shamrock/src/scheduler/PatchScheduler.cpp +++ b/src/shamrock/src/scheduler/PatchScheduler.cpp @@ -18,6 +18,7 @@ #include "shambase/stacktrace.hpp" #include "shambase/string.hpp" #include "shambase/time.hpp" +#include "nlohmann/json_fwd.hpp" #include "shambackends/math.hpp" #include "shambackends/typeAliasVec.hpp" #include "shamrock/legacy/patch/base/patchdata.hpp" @@ -212,8 +213,8 @@ PatchScheduler::PatchScheduler( u64 crit_merge) : pdl_ptr(pdl_ptr), patch_data( - pdl_ptr, - {{0, 0, 0}, {max_axis_patch_coord, max_axis_patch_coord, max_axis_patch_coord}}) { + pdl_ptr, {{0, 0, 0}, {max_axis_patch_coord, max_axis_patch_coord, max_axis_patch_coord}}), + synchronized_data() { crit_patch_split = crit_split; crit_patch_merge = crit_merge; @@ -1021,11 +1022,36 @@ nlohmann::json PatchScheduler::serialize_patch_metadata() { nlohmann::json jsim_box; patch_data.sim_box.to_json(jsim_box); + nlohmann::json jsynchro_data = synchronized_data.to_json(); + return { {"patchtree", patch_tree}, {"patchlist", patch_list}, {"patchdata_layout", pdl_old()}, {"sim_box", jsim_box}, {"crit_patch_split", crit_patch_split}, - {"crit_patch_merge", crit_patch_merge}}; + {"crit_patch_merge", crit_patch_merge}, + {"synchronized_data", jsynchro_data}}; +} + +nlohmann::json SynchronizedData::to_json() { + + nlohmann::json edges{}; + + using namespace shamrock::solvergraph; + + for (const std::string &edgen : container.get_edge_names()) { + container.get_edge_ref(edgen).to_json(edges[edgen]); + } + + return {{"edges", edges}}; +} + +void SynchronizedData::from_json(const nlohmann::json &j) { + + for (auto &el : j.at("edges").items()) { + std::string type = el.value().at("type"); + auto &deserializer = deser_map.at(type); + container.register_edge_ptr_base(el.key(), deserializer(el.value())); + } }