7 #include "gretl/data_store.hpp"
13 #include "gretl/upstream_state.hpp"
20 gretl::State<int>
make_milestone(
const std::vector<FieldState>& states,
const std::vector<ReactionState>& reactions)
22 std::vector<gretl::StateBase> base_states;
23 for (
const auto& s : states) {
24 base_states.push_back(s);
26 for (
const auto& r : reactions) {
27 base_states.push_back(r);
30 auto milestone = states[0].create_state<int,
int>(base_states);
33 []([[maybe_unused]]
const gretl::UpstreamStates& inputs, gretl::DownstreamState& output) { output.set<
int>(0); });
35 []([[maybe_unused]] gretl::UpstreamStates& inputs, [[maybe_unused]]
const gretl::DownstreamState& output) {});
37 return milestone.finalize();
42 const FieldState& shape_disp,
const std::vector<FieldState>& states,
43 const std::vector<FieldState>& params,
44 std::shared_ptr<StateAdvancer> advancer, std::string mech_name,
45 const std::vector<ReactionInfo>& reaction_infos)
49 reaction_infos_(reaction_infos)
51 SLIC_ERROR_IF(states.size() == 0,
"Must have a least 1 state for a mechanics.");
52 field_shape_displacement_ = std::make_unique<FieldState>(shape_disp);
53 for (
size_t i = 0; i < states.size(); ++i) {
54 const auto& s = states[i];
55 field_states_.push_back(s);
56 initial_field_states_.push_back(s);
57 state_name_to_field_index_[s.get()->name()] = i;
58 state_names_.push_back(s.get()->name());
61 for (
size_t i = 0; i < params.size(); ++i) {
62 const auto& p = params[i];
63 field_params_.push_back(p);
64 param_name_to_field_index_[p.get()->name()] = i;
65 param_names_.push_back(p.get()->name());
68 reaction_names_.reserve(reaction_infos_.size());
69 for (
size_t i = 0; i < reaction_infos_.size(); ++i) {
71 reaction_infos_[i].
space ==
nullptr,
72 axom::fmt::format(
"Dual '{}' in physics module '{}' has null FE space.", reaction_infos_[i].
name,
name_));
73 reaction_names_.push_back(reaction_infos_[i].
name);
74 reaction_name_to_reaction_index_[reaction_infos_[i].name] = i;
82 SLIC_ERROR_IF(field_states_.empty(),
"Empty field state during completeSetup()");
83 initializeReactionStates();
88 for (
size_t i = 0; i < initial_field_states_.size(); ++i) {
89 field_states_[i] = initial_field_states_[i];
92 checkpointer_->reset_graph();
93 initializeReactionStates();
100 checkpointer_->finalize_graph();
101 checkpointer_->reset_for_backprop();
102 gretl_assert(checkpointer_->check_validity());
114 state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
115 std::format(
"Could not find field named {0} in mesh with tag \"{1}\" to get", field_name,
mesh_->tag()));
116 size_t state_index = state_name_to_field_index_.at(field_name);
117 return *field_states_[state_index].get();
122 if (reaction_name_to_reaction_index_.find(reaction_name) == reaction_name_to_reaction_index_.end()) {
123 std::string available;
124 for (
auto& n : reaction_names_) available += n +
" ";
125 SLIC_ERROR(axom::fmt::format(
126 "Could not find reaction named {0} in mesh with tag \"{1}\" to get. Available reactions (size {2}): {3}",
127 reaction_name,
mesh_->tag(), reaction_names_.size(), available));
129 size_t reaction_index = reaction_name_to_reaction_index_.at(reaction_name);
131 SLIC_ERROR_IF(reaction_states_.empty() && !reaction_names_.empty(),
132 "Reactions were not computed during advanceState, but were requested.");
135 reaction_index >= reaction_states_.size(),
136 "Reaction reactions not correctly allocated yet, cannot get reaction until after initializationStep is called.");
138 return *reaction_states_[reaction_index].get();
145 axom::fmt::format(
"Due to checkpointing restrictions in smith::DifferentiablePhysics, cannot ask for "
146 "an arbitrary checkpointed reaction cycle, asking for cycle {}, but physics is at cycle {}",
148 return dual(reaction_name);
154 std::format(
"Due to checkpointing restrictions in smith::Mechanics, cannot ask for an arbitrary "
155 "checkpointed cycle, asking for cycle {}, but physics is at cycle {}",
157 return state(state_name);
164 SLIC_ERROR_IF(parameter_index >= field_params_.size(),
165 std::format(
"Parameter index {} requested, but only {} parameters exist in physics module {}.",
166 parameter_index, field_params_.size(),
name_));
167 return *field_params_[parameter_index].get();
173 param_name_to_field_index_.find(parameter_name) == param_name_to_field_index_.end(),
174 std::format(
"Could not find parameter named {0} in mesh with tag \"{1}\" to get", parameter_name,
mesh_->tag()));
175 size_t param_index = param_name_to_field_index_.at(parameter_name);
181 SLIC_ERROR_IF(parameter_index >= field_params_.size(),
182 std::format(
"Parameter '{}' requested when only '{}' parameters exist in physics module '{}'",
183 parameter_index, field_params_.size(),
name_));
184 *field_params_[parameter_index].get() = parameter_state;
189 *field_shape_displacement_->get() = shape_displacement;
195 SLIC_ERROR_IF(state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
196 std::format(
"Could not find field named {0} in mesh with tag {1} to set", field_name,
mesh_->tag()));
197 size_t state_index = state_name_to_field_index_.at(field_name);
198 *field_states_[state_index].get() = s;
199 *initial_field_states_[state_index].get() = s;
203 std::unordered_map<std::string, const smith::FiniteElementDual&> string_to_reaction)
205 for (
auto string_reaction_pair : string_to_reaction) {
206 std::string field_name = string_reaction_pair.first;
209 state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
210 axom::fmt::format(
"Could not find reaction named {0} in mesh with tag {1}", field_name,
mesh_->tag()));
211 size_t state_index = state_name_to_field_index_.at(field_name);
212 *field_states_[state_index].get_dual() += reaction;
217 std::unordered_map<std::string, const smith::FiniteElementState&> string_to_bc)
219 for (
auto string_bc_pair : string_to_bc) {
220 std::string reaction_name = string_bc_pair.first;
223 reaction_name_to_reaction_index_.find(reaction_name) == reaction_name_to_reaction_index_.end(),
224 axom::fmt::format(
"When calling setDualAdjointBcs, could not find reaction named {0} in mesh with tag {1}",
225 reaction_name,
mesh_->tag()));
226 size_t reaction_index = reaction_name_to_reaction_index_.at(reaction_name);
227 *reaction_states_[reaction_index].get_dual() += reaction_adjoint_load;
234 SLIC_ERROR(
"What is the use case for asking for the adjoint solution field directly?");
241 field_states_ = initial_field_states_;
242 milestones_.push_back(
make_milestone(field_states_, reaction_states_).step());
250 auto [states, reactions] =
251 advancer_->advanceState(time_info, *field_shape_displacement_, field_states_, field_params_);
252 field_states_ = states;
253 reaction_states_ = reactions;
257 milestones_.push_back(
make_milestone(field_states_, reaction_states_).step());
263 const gretl::Int milestone = milestones_[
static_cast<size_t>(
cycle_)];
265 field_shape_displacement_->clear_dual();
266 for (
auto& p : field_params_) {
270 gretl::Int current_step = checkpointer_->currentStep_;
271 while (milestone != current_step) {
272 checkpointer_->reverse_state();
273 current_step = checkpointer_->currentStep_;
276 gretl::UpstreamStates upstreams(*checkpointer_, checkpointer_->upstreamSteps_[milestone]);
278 const size_t expected_upstreams = field_states_.size() + reaction_states_.size();
279 SLIC_ERROR_IF(expected_upstreams != upstreams.size(),
"field/reaction states and upstream sizes do not match.");
281 for (
size_t s = 0; s < field_states_.size(); ++s) {
282 field_states_[s].reset_step(upstreams[s].step_);
283 field_states_[s].set(upstreams[s].get<FEFieldPtr>());
284 field_states_[s].set_dual(upstreams[s].get_dual<FEDualPtr, FEFieldPtr>());
286 for (
size_t r = 0; r < reaction_states_.size(); ++r) {
287 const size_t upstream_index = field_states_.size() + r;
288 reaction_states_[r].reset_step(upstreams[upstream_index].step_);
289 reaction_states_[r].set(upstreams[upstream_index].get<FEDualPtr>());
290 reaction_states_[r].set_dual(upstreams[upstream_index].get_dual<FEFieldPtr, FEDualPtr>());
296 return *field_params_[parameter_index].get_dual();
301 return *field_shape_displacement_->get_dual();
304 const std::unordered_map<std::string, const smith::FiniteElementDual&>
307 std::unordered_map<std::string, const smith::FiniteElementDual&> map;
309 auto state_index = state_name_to_field_index_.at(
name);
310 map.insert({
name, *initial_field_states_[state_index].get_dual()});
317 std::vector<FieldState> fields;
318 fields.insert(fields.end(), field_states_.begin(), field_states_.end());
319 fields.insert(fields.end(), field_params_.begin(), field_params_.end());
325 SLIC_ERROR_IF(state_name_to_field_index_.find(state_name) == state_name_to_field_index_.end(),
326 std::format(
"Could not find initial field named {0}", state_name));
327 return initial_field_states_[state_name_to_field_index_.at(state_name)];
332 SLIC_ERROR_IF(state_name_to_field_index_.find(state_name) == state_name_to_field_index_.end(),
333 std::format(
"Could not find field named {0}", state_name));
334 return field_states_[state_name_to_field_index_.at(state_name)];
339 SLIC_ERROR_IF(param_name_to_field_index_.find(param_name) == param_name_to_field_index_.end(),
340 std::format(
"Could not find parameter named {0}", param_name));
341 return field_params_[param_name_to_field_index_.at(param_name)];
346 void DifferentiablePhysics::initializeReactionStates()
348 reaction_states_.clear();
349 reaction_states_.reserve(reaction_infos_.size());
350 for (
const auto& reaction_info : reaction_infos_) {
351 auto reaction = std::make_shared<FiniteElementDual>(*reaction_info.space, reaction_info.name);
This is the abstract base class for a generic forward solver.
std::string name_
Name of the physics module.
std::shared_ptr< smith::Mesh > mesh_
The primary mesh.
int cycle_
Current cycle (forward pass time iteration count)
std::string name() const
Return the name of the physics.
virtual double time() const
Get the current forward-solution time.
double time_
Current time for the forward pass.
virtual int cycle() const
Get the current forward-solution cycle iteration number.
std::vector< const smith::FiniteElementState * > adjoints_
List of finite element adjoint states associated with this physics module.
void resetStates(int cycle=0, double time=0.0) override
Base method to reset physics states to the initial time. This does not reset design parameters or sha...
FiniteElementDual loadCheckpointedDual(const std::string &state_name, int cycle) override
Accessor for getting a single named finite element dual solution from the physics modules at a given ...
FieldState getFieldParam(const std::string &param_name) const
Get a tracked parameter field by name.
void reverseAdjointTimestep() override
Reverse one recorded timestep through the gretl graph.
void completeSetup() override
Complete the setup and allocate the necessary data structures.
std::vector< std::string > parameterNames() const override
Get a vector of the finite element state parameter names.
FieldState getFieldState(const std::string &state_name) const
Get a tracked current state field by name.
void setShapeDisplacement(const FiniteElementState &shape_displacement) override
Set the current shape displacement for the underlying mesh.
FieldState getShapeDispFieldState() const
Get the tracked shape displacement field.
const FiniteElementState & adjoint(const std::string &adjoint_name) const override
Accessor for getting named finite element state adjoint solution from the physics modules.
void setDualAdjointBcs(std::unordered_map< std::string, const smith::FiniteElementState & > string_to_bc) override
Set the dual loads (Dirichlet values) for the adjoint reverse timestep solve. This must be called afte...
std::vector< std::string > dualNames() const override
Get a vector of the finite element state dual (reaction) solution names.
FiniteElementDual computeTimestepSensitivity(size_t parameter_index) override
Compute the implicit sensitivity of the quantity of interest used in defining the adjoint load with r...
virtual void resetAdjointStates() override
Base method to reset physics states back to the end of time to start adjoint calculations again....
DifferentiablePhysics(std::shared_ptr< Mesh > mesh, std::shared_ptr< gretl::DataStore > graph, const FieldState &shape_disp, const std::vector< FieldState > &states, const std::vector< FieldState > &params, std::shared_ptr< StateAdvancer > advancer, std::string physics_name, const std::vector< ReactionInfo > &reaction_infos={})
Construct a differentiable physics wrapper around a state advancer and its tracked fields.
void setAdjointLoad(std::unordered_map< std::string, const smith::FiniteElementDual & > string_to_dual) override
Set the loads for the adjoint reverse timestep solve.
const FiniteElementDual & computeTimestepShapeSensitivity() override
Compute the implicit sensitivity of the quantity of interest used in defining the adjoint load with r...
const FiniteElementState & parameter(std::size_t parameter_index) const override
Accessor for getting indexed finite element state parameter fields from the physics modules.
void setState(const std::string &state_name, const FiniteElementState &s) override
Set the primal solution field values of the underlying physics solver.
FieldState getInitialFieldState(const std::string &state_name) const
Get a tracked initial state field by name.
const FiniteElementState & shapeDisplacement() const override
Accessor for getting the shape displacement field from the physics modules.
FiniteElementState loadCheckpointedState(const std::string &state_name, int cycle) override
Return a state for a stored checkpoint cycle.
std::vector< FieldState > getFieldStatesAndParamStates() const
Get the tracked state fields followed by the tracked parameter fields.
virtual void advanceTimestep(double dt) override
Advance the state variables according to the chosen time integrator.
const FiniteElementState & state(const std::string &state_name) const override
Accessor for getting named finite element state primal solution from the physics modules.
std::vector< std::string > stateNames() const override
Get a vector of the finite element state primal solution names.
const FiniteElementDual & dual(const std::string &dual_name) const override
Accessor for getting named finite element state dual (reaction) solution from the physics modules.
const std::unordered_map< std::string, const smith::FiniteElementDual & > computeInitialConditionSensitivity() const override
Compute the implicit sensitivity of the quantity of interest with respect to the initial condition fi...
void setParameter(const size_t parameter_index, const FiniteElementState &parameter_state) override
Deep copy a parameter field into the internally-owned parameter used for simulations.
Class for encapsulating the dual vector space of a finite element space (i.e. the space of linear for...
Class for encapsulating the critical MFEM components of a primal finite element field.
Defines a BasePhysics implementation backed by FieldState objects and a gretl computational graph.
Smith mesh class which assists in constructing the appropriate parallel mfem meshes and registering a...
Accelerator functionality.
gretl::State< FEFieldPtr, FEDualPtr > FieldState
typedef
gretl::State< int > make_milestone(const std::vector< FieldState > &states, const std::vector< ReactionState > &reactions)
gretl function that creates a dummy state recording all field states and reaction states of interest to the mechani...
ReactionState createReactionState(gretl::DataStore &dataStore, const smith::FEDualPtr &s)
Initialize a ReactionState on the gretl::DataStore with values from s.
mfem::ParFiniteElementSpace & space(FieldState field)
Get the finite element space from the primal field of a field state.
Reaction class which is a named combination of a weak form and a set of Dirichlet-constrained nodes.
Interface and implementations for advancing from one step to the next. Typically these are time integ...
Struct storing time and timestep information.