7 #include "gretl/data_store.hpp"
13 #include "gretl/upstream_state.hpp"
// make_milestone: creates a placeholder gretl::State<int> whose upstreams are every entry
// of `states` followed by every entry of `reactions`, so a single graph step captures all
// of them together.
// NOTE(review): this copy is a scraped listing. The leading numbers are original line
// numbers, and several original lines are missing (braces, and the calls that receive the
// two lambdas below).
// @param states     field states to record. states[0] is dereferenced unconditionally,
//                   so callers must pass a non-empty vector.
// @param reactions  reaction states, recorded after the field states.
// @return           the finalized milestone state.
20 gretl::State<int>
make_milestone(
const std::vector<FieldState>& states,
const std::vector<ReactionState>& reactions)
// Upstream order is fields first, then reactions.
// reverseAdjointTimestep indexes the milestone's upstreams in this same order.
22 std::vector<gretl::StateBase> base_states;
23 for (
const auto& s : states) {
24 base_states.push_back(s);
26 for (
const auto& r : reactions) {
27 base_states.push_back(r);
// The new state is created through states[0] and depends on every collected upstream.
30 auto milestone = states[0].create_state<int,
int>(base_states);
// First callback: writes 0 into the output. The milestone carries no real data
// (presumably this is the forward evaluation).
33 []([[maybe_unused]]
const gretl::UpstreamStates& inputs, gretl::DownstreamState& output) { output.set<
int>(0); });
// Second callback: intentionally empty (presumably the reverse/vjp callback; there is
// nothing to propagate through a placeholder value).
35 []([[maybe_unused]] gretl::UpstreamStates& inputs, [[maybe_unused]]
const gretl::DownstreamState& output) {});
37 return milestone.finalize();
// DifferentiablePhysics constructor. The first line(s) of its signature and its brace
// lines are missing from this copy.
// What it does:
// - stores the tracked state and parameter fields;
// - keeps a second copy of the states in initial_field_states_ (later copied back into
//   field_states_);
// - builds name->index lookups for states, parameters and reactions.
// NOTE(review): the declaration names this parameter `physics_name`, the definition
// `mech_name`. Consider aligning them.
42 const FieldState& shape_disp,
const std::vector<FieldState>& states,
43 const std::vector<FieldState>& params,
44 std::shared_ptr<StateAdvancer> advancer, std::string mech_name,
45 const std::vector<ReactionInfo>& reaction_infos)
49 reaction_infos_(reaction_infos)
// At least one state is required.
// NOTE(review): the runtime message reads "a least"; it likely meant "at least".
51 SLIC_ERROR_IF(states.size() == 0,
"Must have a least 1 state for a mechanics.");
52 field_shape_displacement_ = std::make_unique<FieldState>(shape_disp);
// Register each state by name.
// NOTE(review): operator[] means a duplicate state name silently maps to the last index.
53 for (
size_t i = 0; i < states.size(); ++i) {
54 const auto& s = states[i];
55 field_states_.push_back(s);
56 initial_field_states_.push_back(s);
57 state_name_to_field_index_[s.get()->name()] = i;
58 state_names_.push_back(s.get()->name());
// Same bookkeeping for parameters (the same duplicate-name caveat applies).
61 for (
size_t i = 0; i < params.size(); ++i) {
62 const auto& p = params[i];
63 field_params_.push_back(p);
64 param_name_to_field_index_[p.get()->name()] = i;
65 param_names_.push_back(p.get()->name());
// Validate each reaction's FE space and index reactions by name. The SLIC_ERROR_IF head
// for the null-space check sits on a line missing from this copy.
68 reaction_names_.reserve(reaction_infos_.size());
69 for (
size_t i = 0; i < reaction_infos_.size(); ++i) {
71 reaction_infos_[i].
space ==
nullptr,
72 axom::fmt::format(
"Dual '{}' in physics module '{}' has null FE space.", reaction_infos_[i].
name,
name_));
73 reaction_names_.push_back(reaction_infos_[i].
name);
74 reaction_name_to_reaction_index_[reaction_infos_[i].name] = i;
// completeSetup (as named in the message below): requires at least one field state, then
// rebuilds the reaction states.
82 SLIC_ERROR_IF(field_states_.empty(),
"Empty field state during completeSetup()");
83 initializeReactionStates();
// Presumably the body of resetStates (its signature is missing from this copy).
// Steps:
// - copy every initial field state back into field_states_;
// - reset the checkpointer's graph;
// - rebuild the reaction states.
88 for (
size_t i = 0; i < initial_field_states_.size(); ++i) {
89 field_states_[i] = initial_field_states_[i];
92 checkpointer_->reset_graph();
93 initializeReactionStates();
// Presumably the body of resetAdjointStates (its signature is not visible here; confirm).
// Finalizes the recorded graph, prepares the checkpointer for back-propagation, and
// asserts that the checkpointer is still valid.
100 checkpointer_->finalize_graph();
101 checkpointer_->reset_for_backprop();
102 gretl_assert(checkpointer_->check_validity());
// state(): returns the current primal field registered under `field_name`.
// Reports a SLIC error if no state has that name. The SLIC_ERROR_IF head is on a line
// missing from this copy.
// NOTE(review): this uses std::format while sibling messages use axom::fmt::format.
// NOTE(review): find() followed by at() performs the lookup twice.
114 state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
115 std::format(
"Could not find field named {0} in mesh with tag \"{1}\" to get", field_name,
mesh_->tag()));
116 size_t state_index = state_name_to_field_index_.at(field_name);
117 return *field_states_[state_index].get();
// dual(): returns the current reaction (dual) field registered under `reaction_name`.
// Errors if:
// - the name is unknown;
// - no reaction states exist yet;
// - the index is beyond reaction_states_.
// NOTE(review): the last message reads "Reaction reactions"; the duplicated word looks
// unintended.
122 SLIC_ERROR_IF(reaction_name_to_reaction_index_.find(reaction_name) == reaction_name_to_reaction_index_.end(),
123 axom::fmt::format(
"Could not find reaction named {0} in mesh with tag \"{1}\" to get", reaction_name,
125 size_t reaction_index = reaction_name_to_reaction_index_.at(reaction_name);
// The name resolved above, so reaction_names_ is non-empty here. This check therefore
// fires exactly when reaction_states_ is empty, and gives the more specific message.
127 SLIC_ERROR_IF(reaction_states_.empty() && !reaction_names_.empty(),
128 "Reactions were not computed during advanceState, but were requested.");
131 reaction_index >= reaction_states_.size(),
132 "Reaction reactions not correctly allocated yet, cannot get reaction until after initializationStep is called.");
134 return *reaction_states_[reaction_index].get();
// loadCheckpointedDual(): per the error message, only the physics' current cycle can be
// requested; the comparison itself is on a line missing from this copy.
// Returns dual(reaction_name); the declared return type is FiniteElementDual by value,
// i.e. a copy.
141 axom::fmt::format(
"Due to checkpointing restrictions in smith::DifferentiablePhysics, cannot ask for "
142 "an arbitrary checkpointed reaction cycle, asking for cycle {}, but physics is at cycle {}",
144 return dual(reaction_name);
// loadCheckpointedState(): per the error message, only the physics' current cycle can be
// requested. Returns state(state_name) by value, i.e. a copy.
// NOTE(review): this message names smith::Mechanics, while the dual variant names
// smith::DifferentiablePhysics.
150 std::format(
"Due to checkpointing restrictions in smith::Mechanics, cannot ask for an arbitrary "
151 "checkpointed cycle, asking for cycle {}, but physics is at cycle {}",
153 return state(state_name);
// parameter(index): bounds-checks `parameter_index` against field_params_, then returns
// the referenced parameter field.
160 SLIC_ERROR_IF(parameter_index >= field_params_.size(),
161 std::format(
"Parameter index {} requested, but only {} parameters exist in physics module {}.",
162 parameter_index, field_params_.size(),
name_));
163 return *field_params_[parameter_index].get();
// parameter(name): resolves `parameter_name` to its index, reporting a SLIC error if the
// name is unknown. The return statement is on a line missing from this copy.
169 param_name_to_field_index_.find(parameter_name) == param_name_to_field_index_.end(),
170 std::format(
"Could not find parameter named {0} in mesh with tag \"{1}\" to get", parameter_name,
mesh_->tag()));
171 size_t param_index = param_name_to_field_index_.at(parameter_name);
// setParameter(): bounds-checks the index, then assigns `parameter_state` into the
// tracked parameter's field. The assignment goes through the dereferenced pointer, so the
// object's contents are overwritten rather than the handle being replaced.
177 SLIC_ERROR_IF(parameter_index >= field_params_.size(),
178 std::format(
"Parameter '{}' requested when only '{}' parameters exist in physics module '{}'",
179 parameter_index, field_params_.size(),
name_));
180 *field_params_[parameter_index].get() = parameter_state;
// setShapeDisplacement(): overwrites the contents of the tracked shape-displacement field.
185 *field_shape_displacement_->get() = shape_displacement;
// setState(): looks up the named state (SLIC error if unknown), then writes `s` into both
// the current state and its initial copy, so later resets start from `s`.
// NOTE(review): if field_states_[i] and initial_field_states_[i] share the same
// underlying field, the second write is redundant. Confirm.
191 SLIC_ERROR_IF(state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
192 std::format(
"Could not find field named {0} in mesh with tag {1} to set", field_name,
mesh_->tag()));
193 size_t state_index = state_name_to_field_index_.at(field_name);
194 *field_states_[state_index].get() = s;
195 *initial_field_states_[state_index].get() = s;
// setAdjointLoad(): for each (state name, dual) entry, adds the dual into that state's
// dual (adjoint load).
// Behavior notes:
// - `+=` means repeated calls accumulate.
// - `reaction` is bound on a line missing from this copy (presumably the pair's second).
// NOTE(review): the map is taken by value and each pair is copied in the loop.
//   const& would avoid both copies.
// NOTE(review): the lookup is over state names, but the message says "reaction".
199 std::unordered_map<std::string, const smith::FiniteElementDual&> string_to_reaction)
201 for (
auto string_reaction_pair : string_to_reaction) {
202 std::string field_name = string_reaction_pair.first;
205 state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
206 axom::fmt::format(
"Could not find reaction named {0} in mesh with tag {1}", field_name,
mesh_->tag()));
207 size_t state_index = state_name_to_field_index_.at(field_name);
208 *field_states_[state_index].get_dual() += reaction;
// setDualAdjointBcs(): for each (reaction name, field) entry, adds the field into that
// reaction state's dual. `reaction_adjoint_load` is bound on a line missing from this copy.
// NOTE(review): unlike dual(), there is no check that reaction_index < reaction_states_.size().
// NOTE(review): the map and each pair are copied by value.
213 std::unordered_map<std::string, const smith::FiniteElementState&> string_to_bc)
215 for (
auto string_bc_pair : string_to_bc) {
216 std::string reaction_name = string_bc_pair.first;
219 reaction_name_to_reaction_index_.find(reaction_name) == reaction_name_to_reaction_index_.end(),
220 axom::fmt::format(
"When calling setDualAdjointBcs, could not find reaction named {0} in mesh with tag {1}",
221 reaction_name,
mesh_->tag()));
222 size_t reaction_index = reaction_name_to_reaction_index_.at(reaction_name);
223 *reaction_states_[reaction_index].get_dual() += reaction_adjoint_load;
// adjoint(): unsupported. It unconditionally reports a SLIC error.
230 SLIC_ERROR(
"What is the use case for asking for the adjoint solution field directly?");
// Fragment of a method whose signature is missing from this copy.
// Restores the initial field states, then records a milestone step for them together
// with the current reaction states.
237 field_states_ = initial_field_states_;
238 milestones_.push_back(
make_milestone(field_states_, reaction_states_).step());
// advanceTimestep(): asks the advancer for the next field and reaction states, adopts
// them, and records a milestone step for this cycle. `time_info` is built on a line
// missing from this copy.
// NOTE(review): `states`/`reactions` are copied into the members; moving them would
// avoid the copies.
246 auto [states, reactions] =
247 advancer_->advanceState(time_info, *field_shape_displacement_, field_states_, field_params_);
248 field_states_ = states;
249 reaction_states_ = reactions;
253 milestones_.push_back(
make_milestone(field_states_, reaction_states_).step());
// reverseAdjointTimestep(): rewinds the gretl checkpointer to the milestone recorded for
// cycle_, then restores field and reaction states (step ids, values, duals) from that
// milestone's upstreams.
// NOTE(review): milestones_[cycle_] is unchecked. Confirm cycle_ < milestones_.size().
//   Any update of cycle_ is on lines missing from this copy.
259 const gretl::Int milestone = milestones_[
static_cast<size_t>(
cycle_)];
261 field_shape_displacement_->clear_dual();
// Loop over parameters; its body is on lines missing from this copy (presumably clears
// each parameter's dual).
262 for (
auto& p : field_params_) {
// Step the checkpointer backward until it sits on the milestone's step.
// NOTE(review): this loop only terminates if reverse_state() eventually reaches `milestone`.
266 gretl::Int current_step = checkpointer_->currentStep_;
267 while (milestone != current_step) {
268 checkpointer_->reverse_state();
269 current_step = checkpointer_->currentStep_;
272 gretl::UpstreamStates upstreams(*checkpointer_, checkpointer_->upstreamSteps_[milestone]);
// make_milestone records fields first, then reactions, so the counts must line up exactly.
274 const size_t expected_upstreams = field_states_.size() + reaction_states_.size();
275 SLIC_ERROR_IF(expected_upstreams != upstreams.size(),
"field/reaction states and upstream sizes do not match.");
// Field states occupy upstream slots [0, field_states_.size()).
277 for (
size_t s = 0; s < field_states_.size(); ++s) {
278 field_states_[s].reset_step(upstreams[s].step_);
279 field_states_[s].set(upstreams[s].get<FEFieldPtr>());
280 field_states_[s].set_dual(upstreams[s].get_dual<FEDualPtr, FEFieldPtr>());
// Reaction states follow the field states; note the swapped primal/dual template types.
282 for (
size_t r = 0; r < reaction_states_.size(); ++r) {
283 const size_t upstream_index = field_states_.size() + r;
284 reaction_states_[r].reset_step(upstreams[upstream_index].step_);
285 reaction_states_[r].set(upstreams[upstream_index].get<FEDualPtr>());
286 reaction_states_[r].set_dual(upstreams[upstream_index].get_dual<FEFieldPtr, FEDualPtr>());
// computeTimestepSensitivity(): returns the dual accumulated on the indexed parameter.
// The declared return type is by value, so this is a copy; no bounds check is visible here.
292 return *field_params_[parameter_index].get_dual();
// computeTimestepShapeSensitivity(): returns a reference to the dual accumulated on the
// shape-displacement field.
297 return *field_shape_displacement_->get_dual();
// computeInitialConditionSensitivity(): builds a name -> dual map from the initial field
// states. The loop producing `name` is on a line missing from this copy (presumably over
// the state names). at() throws std::out_of_range for an unregistered name.
// NOTE(review): the returned map holds references into initial_field_states_, and the
// const by-value return type prevents moving the result.
300 const std::unordered_map<std::string, const smith::FiniteElementDual&>
303 std::unordered_map<std::string, const smith::FiniteElementDual&> map;
305 auto state_index = state_name_to_field_index_.at(
name);
306 map.insert({
name, *initial_field_states_[state_index].get_dual()});
// getFieldStatesAndParamStates(): concatenates the tracked state fields followed by the
// tracked parameter fields into a new vector.
// NOTE(review): reserving field_states_.size() + field_params_.size() would avoid a
// possible reallocation.
313 std::vector<FieldState> fields;
314 fields.insert(fields.end(), field_states_.begin(), field_states_.end());
315 fields.insert(fields.end(), field_params_.begin(), field_params_.end());
// initializeReactionStates(): discards existing reaction states and, for each
// ReactionInfo, allocates a FiniteElementDual on that info's space with that info's name.
// The rest of the loop body is not visible in this copy.
321 void DifferentiablePhysics::initializeReactionStates()
323 reaction_states_.clear();
324 reaction_states_.reserve(reaction_infos_.size());
325 for (
const auto& reaction_info : reaction_infos_) {
326 auto reaction = std::make_shared<FiniteElementDual>(*reaction_info.space, reaction_info.name);
This is the abstract base class for a generic forward solver.
std::string name_
Name of the physics module.
std::shared_ptr< smith::Mesh > mesh_
The primary mesh.
int cycle_
Current cycle (forward pass time iteration count)
std::string name() const
Return the name of the physics.
virtual double time() const
Get the current forward-solution time.
double time_
Current time for the forward pass.
virtual int cycle() const
Get the current forward-solution cycle iteration number.
std::vector< const smith::FiniteElementState * > adjoints_
List of finite element adjoint states associated with this physics module.
void resetStates(int cycle=0, double time=0.0) override
Base method to reset physics states to the initial time. This does not reset design parameters or sha...
FiniteElementDual loadCheckpointedDual(const std::string &state_name, int cycle) override
Accessor for getting a single named finite element dual solution from the physics modules at a given ...
void reverseAdjointTimestep() override
Reverse one recorded timestep through the gretl graph.
void completeSetup() override
Complete the setup and allocate the necessary data structures.
std::vector< std::string > parameterNames() const override
Get a vector of the finite element state parameter names.
void setShapeDisplacement(const FiniteElementState &shape_displacement) override
Set the current shape displacement for the underlying mesh.
FieldState getShapeDispFieldState() const
Get the tracked shape displacement field.
const FiniteElementState & adjoint(const std::string &adjoint_name) const override
Accessor for getting named finite element state adjoint solution from the physics modules.
void setDualAdjointBcs(std::unordered_map< std::string, const smith::FiniteElementState & > string_to_bc) override
Set the dual loads (Dirichlet values) for the adjoint reverse timestep solve. This must be called after...
std::vector< std::string > dualNames() const override
Get a vector of the finite element state dual (reaction) solution names.
FiniteElementDual computeTimestepSensitivity(size_t parameter_index) override
Compute the implicit sensitivity of the quantity of interest used in defining the adjoint load with r...
virtual void resetAdjointStates() override
Base method to reset physics states back to the end of time to start adjoint calculations again....
DifferentiablePhysics(std::shared_ptr< Mesh > mesh, std::shared_ptr< gretl::DataStore > graph, const FieldState &shape_disp, const std::vector< FieldState > &states, const std::vector< FieldState > &params, std::shared_ptr< StateAdvancer > advancer, std::string physics_name, const std::vector< ReactionInfo > &reaction_infos={})
Construct a differentiable physics wrapper around a state advancer and its tracked fields.
void setAdjointLoad(std::unordered_map< std::string, const smith::FiniteElementDual & > string_to_dual) override
Set the loads for the adjoint reverse timestep solve.
const FiniteElementDual & computeTimestepShapeSensitivity() override
Compute the implicit sensitivity of the quantity of interest used in defining the adjoint load with r...
const FiniteElementState & parameter(std::size_t parameter_index) const override
Accessor for getting indexed finite element state parameter fields from the physics modules.
void setState(const std::string &state_name, const FiniteElementState &s) override
Set the primal solution field values of the underlying physics solver.
const FiniteElementState & shapeDisplacement() const override
Accessor for getting the shape displacement field from the physics modules.
FiniteElementState loadCheckpointedState(const std::string &state_name, int cycle) override
Return a state for a stored checkpoint cycle.
std::vector< FieldState > getFieldStatesAndParamStates() const
Get the tracked state fields followed by the tracked parameter fields.
virtual void advanceTimestep(double dt) override
Advance the state variables according to the chosen time integrator.
const FiniteElementState & state(const std::string &state_name) const override
Accessor for getting named finite element state primal solution from the physics modules.
std::vector< std::string > stateNames() const override
Get a vector of the finite element state primal solution names.
const FiniteElementDual & dual(const std::string &dual_name) const override
Accessor for getting named finite element state dual (reaction) solution from the physics modules.
const std::unordered_map< std::string, const smith::FiniteElementDual & > computeInitialConditionSensitivity() const override
Compute the implicit sensitivity of the quantity of interest with respect to the initial condition fi...
void setParameter(const size_t parameter_index, const FiniteElementState &parameter_state) override
Deep copy a parameter field into the internally-owned parameter used for simulations.
Class for encapsulating the dual vector space of a finite element space (i.e. the space of linear for...
Class for encapsulating the critical MFEM components of a primal finite element field.
Defines a BasePhysics implementation backed by FieldState objects and a gretl computational graph.
Smith mesh class which assists in constructing the appropriate parallel mfem meshes and registering a...
Accelerator functionality.
gretl::State< FEFieldPtr, FEDualPtr > FieldState
typedef
gretl::State< int > make_milestone(const std::vector< FieldState > &states, const std::vector< ReactionState > &reactions)
gretl function that creates a dummy state recording all states and params of interest to the mechani...
ReactionState createReactionState(gretl::DataStore &dataStore, const smith::FEDualPtr &s)
Initialize a ReactionState on the gretl::DataStore with values from s.
mfem::ParFiniteElementSpace & space(FieldState field)
Get the finite element space from the primal field of a field state.
Reaction class which is a named combination of a weak form and a set of Dirichlet-constrained nodes.
Interface and implementations for advancing from one step to the next. Typically these are time integ...
Struct storing time and timestep information.