Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions src/smith/physics/solid_mechanics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1264,15 +1264,32 @@ class SolidMechanics<order, dim, Parameters<parameter_space...>, std::integer_se
{
SLIC_ERROR_ROOT_IF(bcs.size() == 0, "Adjoint load container size must be greater than 0 in the solid mechanics.");

auto reaction_adjoint_load = bcs.find("reactions");

SLIC_ERROR_ROOT_IF(reaction_adjoint_load == bcs.end(), "Adjoint load for \"reaction\" not found.");
for (const auto& [name, bc] : bcs) {
SLIC_ERROR_ROOT_IF(!trySetDualAdjointBc(name, bc),
std::format("Unknown dual adjoint BC '{}' for solid mechanics module '{}'.", name, name_));
}
}

if (reaction_adjoint_load != bcs.end()) {
reactions_adjoint_bcs_ = reaction_adjoint_load->second;
protected:
/**
* @brief Apply a single dual-adjoint boundary condition if this class owns the named dual
*
* @param[in] dual_name Name of the dual-adjoint BC to apply
* @param[in] bc Boundary condition values for the dual adjoint
* @return true if the key was recognized and applied
* @return false if the key is not owned by this class
*/
virtual bool trySetDualAdjointBc(const std::string& dual_name, const smith::FiniteElementState& bc)
{
if (dual_name == "reactions") {
reactions_adjoint_bcs_ = bc;
return true;
}

return false;
}

public:
/// @overload
void reverseAdjointTimestep() override
{
Expand Down
151 changes: 105 additions & 46 deletions src/smith/physics/solid_mechanics_contact.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,13 @@ class SolidMechanicsContact<order, dim, Parameters<parameter_space...>,
: SolidMechanicsBase(std::move(solver), timestepping_opts, physics_name, smith_mesh, parameter_names, cycle, time,
checkpoint_to_disk, use_warm_start),
contact_(BasePhysics::mfemParMesh()),
forces_(StateManager::newDual(displacement_.space(), detail::addPrefix(physics_name, "contact_forces")))
forces_(StateManager::newDual(displacement_.space(), detail::addPrefix(physics_name, "contact_forces"))),
contact_force_adjoint_bcs_(displacement_.space(), detail::addPrefix(physics_name, "contact_forces_adjoint_bcs"))
{
forces_ = 0;
duals_.push_back(&forces_);
contact_force_adjoint_bcs_ = 0.0;
this->dual_adjoints_.push_back(&contact_force_adjoint_bcs_);
}

/**
Expand All @@ -112,17 +115,21 @@ class SolidMechanicsContact<order, dim, Parameters<parameter_space...>,
std::shared_ptr<smith::Mesh> smith_mesh, int cycle = 0, double time = 0.0)
: SolidMechanicsBase(input_options, physics_name, smith_mesh, cycle, time),
contact_(BasePhysics::mfemParMesh()),
forces_(StateManager::newDual(displacement_.space(), detail::addPrefix(physics_name, "contact_forces")))
forces_(StateManager::newDual(displacement_.space(), detail::addPrefix(physics_name, "contact_forces"))),
contact_force_adjoint_bcs_(displacement_.space(), detail::addPrefix(physics_name, "contact_forces_adjoint_bcs"))
{
forces_ = 0;
duals_.push_back(&forces_);
contact_force_adjoint_bcs_ = 0.0;
this->dual_adjoints_.push_back(&contact_force_adjoint_bcs_);
}

/// @overload
void resetStates(int cycle = 0, double time = 0.0) override
{
SolidMechanicsBase::resetStates(cycle, time);
forces_ = 0.0;
contact_force_adjoint_bcs_ = 0.0;
for (auto& [_, force] : contact_interaction_forces_) {
if (force) {
*force = 0.0;
Expand All @@ -138,6 +145,7 @@ class SolidMechanicsContact<order, dim, Parameters<parameter_space...>,
mfem::Vector p(contact_.numPressureDofs());
p = 0.0;
contact_.updateForcesAndJacobian(cycle, time, dt, BasePhysics::shapeDisplacement(), displacement_, p);
updateContactForceOutputs();
}

/// @brief Build the quasi-static operator corresponding to the total Lagrangian formulation
Expand Down Expand Up @@ -291,6 +299,10 @@ class SolidMechanicsContact<order, dim, Parameters<parameter_space...>,
/// @overload
const FiniteElementState& dualAdjoint(const std::string& dual_name) const override
{
if (isAggregateContactForceName(dual_name)) {
return contact_force_adjoint_bcs_;
}

const auto interaction_id = parseContactInteractionForceId(dual_name);
if (interaction_id.has_value()) {
auto it = contact_interaction_force_adjoint_bcs_.find(*interaction_id);
Expand All @@ -302,6 +314,36 @@ class SolidMechanicsContact<order, dim, Parameters<parameter_space...>,
return SolidMechanicsBase::dualAdjoint(dual_name);
}

/// @overload
FiniteElementDual loadCheckpointedDual(const std::string& dual_name, int cycle) override
{
if (isAggregateContactForceName(dual_name) || parseContactInteractionForceId(dual_name).has_value()) {
SLIC_ERROR_ROOT_IF(contact_.haveLagrangeMultipliers(),
"Checkpointed retrieval of contact forces is not supported for Lagrange multiplier contact.");

const FiniteElementState checkpointed_displacement = this->loadCheckpointedState("displacement", cycle);
double dt = this->getCheckpointedTimestep(cycle);
contact_.updateForcesAndJacobian(cycle, time_, dt, BasePhysics::shapeDisplacement(), checkpointed_displacement);
updateContactForceOutputs();

if (isAggregateContactForceName(dual_name)) {
return forces_;
}

const auto interaction_id = parseContactInteractionForceId(dual_name);
SLIC_ERROR_ROOT_IF(
!interaction_id.has_value(),
std::format("Requested checkpointed dual '{}' does not exist in physics module '{}'.", dual_name, name_));
auto it = contact_interaction_forces_.find(*interaction_id);
SLIC_ERROR_ROOT_IF(
it == contact_interaction_forces_.end(),
std::format("Requested checkpointed dual '{}' does not exist in physics module '{}'.", dual_name, name_));
return *it->second;
}

return SolidMechanicsBase::loadCheckpointedDual(dual_name, cycle);
}

/**
* @brief create a contactSubspaceTransferOperator for AMGF
*/
Expand All @@ -324,6 +366,7 @@ class SolidMechanicsContact<order, dim, Parameters<parameter_space...>,
double dt = 0.0;
mfem::Vector p = pressure();
contact_.updateForcesAndJacobian(cycle_, time_, dt, BasePhysics::shapeDisplacement(), displacement_, p);
updateContactForceOutputs();

SolidMechanicsBase::completeSetup();
}
Expand Down Expand Up @@ -354,47 +397,50 @@ class SolidMechanicsContact<order, dim, Parameters<parameter_space...>,
}
#endif

protected:
/// @overload
void setDualAdjointBcs(std::unordered_map<std::string, const smith::FiniteElementState&> bcs) override
bool trySetDualAdjointBc(const std::string& dual_name, const smith::FiniteElementState& bc) override
{
SLIC_ERROR_ROOT_IF(bcs.size() == 0, "Adjoint load container size must be greater than 0 in SolidMechanicsContact.");

auto reaction_adjoint_load = bcs.find("reactions");
if (reaction_adjoint_load != bcs.end()) {
SolidMechanicsBase::setDualAdjointBcs({{"reactions", reaction_adjoint_load->second}});
if (isAggregateContactForceName(dual_name)) {
contact_force_adjoint_bcs_ = bc;
return true;
}

for (const auto& [name, bc] : bcs) {
if (name == "reactions") {
continue;
}

const auto interaction_id = parseContactInteractionForceId(name);
SLIC_ERROR_ROOT_IF(!interaction_id.has_value(),
std::format("Unknown dual adjoint BC '{}' for SolidMechanicsContact.", name));

const auto interaction_id = parseContactInteractionForceId(dual_name);
if (interaction_id.has_value()) {
auto it = contact_interaction_force_adjoint_bcs_.find(*interaction_id);
SLIC_ERROR_ROOT_IF(it == contact_interaction_force_adjoint_bcs_.end(),
std::format("No contact force adjoint BC registered for interaction_id={}", *interaction_id));

*it->second = bc;
return true;
}

return SolidMechanicsBase::trySetDualAdjointBc(dual_name, bc);
}

protected:
/// @brief Converts a dual name into an interaction id (if it exists)
static std::optional<int> parseContactInteractionForceId(std::string_view dual_name)
std::optional<int> parseContactInteractionForceId(std::string_view dual_name) const
{
constexpr std::string_view prefix = "contact_force_";
const std::string normalized_name = detail::removePrefix(this->name_, std::string(dual_name));
const std::string_view normalized_name_view{normalized_name};

// Accept both the bare name and the module-prefixed name, e.g. "solid_contact_force_0".
const auto idx = dual_name.rfind(prefix);
if (idx == std::string_view::npos) {
if (normalized_name_view.substr(0, prefix.size()) != prefix) {
return std::nullopt;
}

// This code converts everything after the prefix to a candidate id
const std::string_view id_str = dual_name.substr(idx + prefix.size());
return parseInteractionId(normalized_name_view.substr(prefix.size()));
}

/// @brief Returns true if @p dual_name refers to the merged contact force dual
bool isAggregateContactForceName(std::string_view dual_name) const
{
return detail::removePrefix(this->name_, std::string(dual_name)) == "contact_forces";
}

/// @brief Parses the interaction id from the suffix of a contact force dual name
static std::optional<int> parseInteractionId(std::string_view id_str)
{
int interaction_id = -1;
const auto* begin = id_str.data();
const auto* end = id_str.data() + id_str.size();
Expand Down Expand Up @@ -425,16 +471,7 @@ class SolidMechanicsContact<order, dim, Parameters<parameter_space...>,
// solve the non-linear system resid = 0 and pressure * gap = 0
nonlin_solver_->solve(augmented_solution);
displacement_.Set(1.0, mfem::Vector(augmented_solution, 0, displacement_.Size()));
forces_.SetVector(contact_.forces(), 0);

#ifdef SMITH_USE_TRIBOL
for (const auto& interaction : contact_.getContactInteractions()) {
auto it = contact_interaction_forces_.find(interaction.getInteractionId());
if (it != contact_interaction_forces_.end()) {
it->second->SetVector(interaction.forces(), 0);
}
}
#endif
updateContactForceOutputs();
}

/**
Expand Down Expand Up @@ -606,27 +643,32 @@ class SolidMechanicsContact<order, dim, Parameters<parameter_space...>,
// Following SolidMechanics::setAdjointLoad() sign convention, displacement_adjoint_load_ stores the negative of the
// provided dJ/du, so we subtract these contributions here.
#ifdef SMITH_USE_TRIBOL
if (!contact_interaction_force_adjoint_bcs_.empty()) {
const bool have_aggregate_contact_force_seed = contact_force_adjoint_bcs_.Norml2() != 0.0;
if (!contact_interaction_force_adjoint_bcs_.empty() || have_aggregate_contact_force_seed) {
FiniteElementDual contact_force_load(displacement_.space(), "contact_force_dual_adjoint_load");
contact_force_load = 0.0;

for (const auto& [interaction_id, force_seed] : contact_interaction_force_adjoint_bcs_) {
if (!force_seed) {
continue;
for (const auto& interaction : contact_.getContactInteractions()) {
const int interaction_id = interaction.getInteractionId();
const auto interaction_J = contactInteraction(interaction_id).jacobianContribution();
auto* J00 = dynamic_cast<mfem::HypreParMatrix*>(&interaction_J->GetBlock(0, 0));
SLIC_ERROR_ROOT_IF(!J00, "Expected HypreParMatrix (0,0) block for contact interaction Jacobian.");

if (have_aggregate_contact_force_seed) {
FiniteElementDual tmp(displacement_.space(), "contact_force_dual_adjoint_load_tmp");
tmp = 0.0;
J00->MultTranspose(contact_force_adjoint_bcs_, tmp);
contact_force_load.Add(1.0, tmp);
}

// Only apply if the seed is nonzero.
if (force_seed->Norml2() == 0.0) {
auto it = contact_interaction_force_adjoint_bcs_.find(interaction_id);
if (it == contact_interaction_force_adjoint_bcs_.end() || !it->second || it->second->Norml2() == 0.0) {
continue;
}

const auto interaction_J = contactInteraction(interaction_id).jacobianContribution();
auto* J00 = dynamic_cast<mfem::HypreParMatrix*>(&interaction_J->GetBlock(0, 0));
SLIC_ERROR_ROOT_IF(!J00, "Expected HypreParMatrix (0,0) block for contact interaction Jacobian.");

FiniteElementDual tmp(displacement_.space(), "contact_force_dual_adjoint_load_tmp");
tmp = 0.0;
J00->MultTranspose(*force_seed, tmp);
J00->MultTranspose(*it->second, tmp);
contact_force_load.Add(1.0, tmp);
}

Expand Down Expand Up @@ -668,6 +710,20 @@ class SolidMechanicsContact<order, dim, Parameters<parameter_space...>,
return BasePhysics::shapeDisplacementSensitivity();
}

/// @brief Update the cached contact-force outputs from the current Tribol state
void updateContactForceOutputs()
{
forces_ = contact_.forces();
#ifdef SMITH_USE_TRIBOL
for (const auto& interaction : contact_.getContactInteractions()) {
auto it = contact_interaction_forces_.find(interaction.getInteractionId());
if (it != contact_interaction_forces_.end()) {
*it->second = interaction.forces();
}
}
#endif
}

using BasePhysics::bcs_;
using BasePhysics::cycle_;
using BasePhysics::duals_;
Expand Down Expand Up @@ -726,6 +782,9 @@ class SolidMechanicsContact<order, dim, Parameters<parameter_space...>,
/// forces for output
FiniteElementDual forces_;

/// merged dual-adjoint (BC) field for the total contact force dual
FiniteElementState contact_force_adjoint_bcs_;

/// per-interaction contact forces for output
std::unordered_map<int, std::unique_ptr<FiniteElementDual>> contact_interaction_forces_;

Expand Down
77 changes: 75 additions & 2 deletions src/smith/physics/tests/contact_solid_adjoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,7 @@ double computeSolidMechanicsQoi(BasePhysics& solid_solver, const TimeSteppingInf
{
auto dts = ts_info.dts;
solid_solver.resetStates();
solid_solver.outputStateToDisk("paraview_contact");
solid_solver.advanceTimestep(1.0);
solid_solver.outputStateToDisk("paraview_contact");
return computeStepQoi(solid_solver.state("displacement"));
}

Expand Down Expand Up @@ -419,6 +417,81 @@ TEST_F(ContactSensitivityFixture, ContactForceDualAdjointBcsMatchesEquivalentAdj
EXPECT_NEAR(diff.Norml2(), 0.0, 1.0e-10);
}

TEST_F(ContactSensitivityFixture, AggregateContactForcesDualAdjointBcsMatchesEquivalentAdjointLoad)
{
auto solver = createContactSolver(mesh, nonlinear_opts, dyn_opts, mat);

solver->resetStates();
solver->advanceTimestep(1.0);
EXPECT_EQ(1, solver->cycle());

FiniteElementState dJ_df_total(solver->state("displacement").space(), "dJ_df_total");
fillDirection(*solver, dJ_df_total);

FiniteElementDual zero_load(solver->state("displacement").space(), "zero_load");
zero_load = 0.0;
solver->setAdjointLoad({{"displacement", zero_load}});
solver->setDualAdjointBcs({{"contact_forces", dJ_df_total}});
solver->reverseAdjointTimestep();

FiniteElementState lambda_from_seed(solver->adjoint("displacement"));

solver->resetStates();
solver->advanceTimestep(1.0);
EXPECT_EQ(1, solver->cycle());

FiniteElementDual equivalent_load(solver->state("displacement").space(), "equivalent_load");
equivalent_load = 0.0;
for (const auto& dual_name : solver->dualNames()) {
const std::string prefix = "contact_force_";
if (dual_name.substr(0, prefix.size()) != prefix) {
continue;
}

const int interaction_id = std::stoi(dual_name.substr(prefix.size()));
const auto interaction_jacobian = solver->contactInteraction(interaction_id).jacobianContribution();
auto* J00 = dynamic_cast<mfem::HypreParMatrix*>(&interaction_jacobian->GetBlock(0, 0));
SLIC_ERROR_ROOT_IF(!J00, "Expected HypreParMatrix (0,0) block for contact interaction Jacobian.");

FiniteElementDual interaction_load(solver->state("displacement").space(), "interaction_load");
interaction_load = 0.0;
J00->MultTranspose(dJ_df_total, interaction_load);
equivalent_load.Add(1.0, interaction_load);
}

solver->setAdjointLoad({{"displacement", equivalent_load}});
solver->reverseAdjointTimestep();

const auto& lambda_from_load = solver->adjoint("displacement");

FiniteElementState diff(lambda_from_seed);
diff.Add(-1.0, lambda_from_load);

EXPECT_NEAR(diff.Norml2(), 0.0, 1.0e-10);
}

TEST_F(ContactSensitivityFixture, AggregateContactForcesDualAdjointAcceptsPrefixedName)
{
auto solver = createContactSolver(mesh, nonlinear_opts, dyn_opts, mat);

FiniteElementState seed(solver->state("displacement").space(), "aggregate_contact_force_seed");
fillDirection(*solver, seed);

const auto prefixed_name = solver->dual("contact_forces").name();
solver->setDualAdjointBcs({{prefixed_name, seed}});

const auto& stored_seed = solver->dualAdjoint("contact_forces");
const auto& stored_prefixed_seed = solver->dualAdjoint(prefixed_name);

FiniteElementState diff(stored_seed);
diff.Add(-1.0, seed);
EXPECT_NEAR(diff.Norml2(), 0.0, 1.0e-12);

FiniteElementState prefixed_diff(stored_prefixed_seed);
prefixed_diff.Add(-1.0, seed);
EXPECT_NEAR(prefixed_diff.Norml2(), 0.0, 1.0e-12);
}

} // namespace smith

int main(int argc, char* argv[])
Expand Down
Loading