Smith  0.1
Smith is an implicit thermal structural mechanics simulation code.
differentiable_physics.cpp
1 // Copyright (c) Lawrence Livermore National Security, LLC and
2 // other Smith Project Developers. See the top-level LICENSE file for
3 // details.
4 //
5 // SPDX-License-Identifier: (BSD-3-Clause)
6 
#include <utility>

#include "gretl/data_store.hpp"
#include "smith/physics/mesh.hpp"
#include "gretl/upstream_state.hpp"
14 
15 namespace smith {
16 
20 gretl::State<int> make_milestone(const std::vector<FieldState>& states, const std::vector<ReactionState>& reactions)
21 {
22  std::vector<gretl::StateBase> base_states;
23  for (const auto& s : states) {
24  base_states.push_back(s);
25  }
26  for (const auto& r : reactions) {
27  base_states.push_back(r);
28  }
29 
30  auto milestone = states[0].create_state<int, int>(base_states);
31 
32  milestone.set_eval(
33  []([[maybe_unused]] const gretl::UpstreamStates& inputs, gretl::DownstreamState& output) { output.set<int>(0); });
34  milestone.set_vjp(
35  []([[maybe_unused]] gretl::UpstreamStates& inputs, [[maybe_unused]] const gretl::DownstreamState& output) {});
36 
37  return milestone.finalize();
38 }
39 
40 // mesh, equation, fields, parameters, state advancer, solver
41 DifferentiablePhysics::DifferentiablePhysics(std::shared_ptr<Mesh> mesh, std::shared_ptr<gretl::DataStore> graph,
42  const FieldState& shape_disp, const std::vector<FieldState>& states,
43  const std::vector<FieldState>& params,
44  std::shared_ptr<StateAdvancer> advancer, std::string mech_name,
45  const std::vector<ReactionInfo>& reaction_infos)
46  : BasePhysics(mech_name, mesh, 0, 0.0, false), // the false is checkpoint_to_disk
47  checkpointer_(graph),
48  advancer_(advancer),
49  reaction_infos_(reaction_infos)
50 {
51  SLIC_ERROR_IF(states.size() == 0, "Must have a least 1 state for a mechanics.");
52  field_shape_displacement_ = std::make_unique<FieldState>(shape_disp);
53  for (size_t i = 0; i < states.size(); ++i) {
54  const auto& s = states[i];
55  field_states_.push_back(s);
56  initial_field_states_.push_back(s);
57  state_name_to_field_index_[s.get()->name()] = i;
58  state_names_.push_back(s.get()->name());
59  }
60 
61  for (size_t i = 0; i < params.size(); ++i) {
62  const auto& p = params[i];
63  field_params_.push_back(p);
64  param_name_to_field_index_[p.get()->name()] = i;
65  param_names_.push_back(p.get()->name());
66  }
67 
68  reaction_names_.reserve(reaction_infos_.size());
69  for (size_t i = 0; i < reaction_infos_.size(); ++i) {
70  SLIC_ERROR_IF(
71  reaction_infos_[i].space == nullptr,
72  axom::fmt::format("Dual '{}' in physics module '{}' has null FE space.", reaction_infos_[i].name, name_));
73  reaction_names_.push_back(reaction_infos_[i].name);
74  reaction_name_to_reaction_index_[reaction_infos_[i].name] = i;
75  }
76 
77  completeSetup();
78 }
79 
81 {
82  SLIC_ERROR_IF(field_states_.empty(), "Empty field state during completeSetup()");
83  initializeReactionStates();
84 }
85 
86 void DifferentiablePhysics::resetStates(int cycle, double time)
87 {
88  for (size_t i = 0; i < initial_field_states_.size(); ++i) {
89  field_states_[i] = initial_field_states_[i];
90  }
91  milestones_.clear();
92  checkpointer_->reset_graph();
93  initializeReactionStates();
94  time_ = time;
95  cycle_ = cycle;
96 }
97 
99 {
100  checkpointer_->finalize_graph();
101  checkpointer_->reset_for_backprop();
102  gretl_assert(checkpointer_->check_validity());
103 }
104 
105 std::vector<std::string> DifferentiablePhysics::stateNames() const { return state_names_; }
106 
107 std::vector<std::string> DifferentiablePhysics::parameterNames() const { return param_names_; }
108 
109 std::vector<std::string> DifferentiablePhysics::dualNames() const { return reaction_names_; }
110 
111 const FiniteElementState& DifferentiablePhysics::state([[maybe_unused]] const std::string& field_name) const
112 {
113  SLIC_ERROR_IF(
114  state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
115  std::format("Could not find field named {0} in mesh with tag \"{1}\" to get", field_name, mesh_->tag()));
116  size_t state_index = state_name_to_field_index_.at(field_name);
117  return *field_states_[state_index].get();
118 }
119 
120 const FiniteElementDual& DifferentiablePhysics::dual(const std::string& reaction_name) const
121 {
122  SLIC_ERROR_IF(reaction_name_to_reaction_index_.find(reaction_name) == reaction_name_to_reaction_index_.end(),
123  axom::fmt::format("Could not find reaction named {0} in mesh with tag \"{1}\" to get", reaction_name,
124  mesh_->tag()));
125  size_t reaction_index = reaction_name_to_reaction_index_.at(reaction_name);
126 
127  SLIC_ERROR_IF(reaction_states_.empty() && !reaction_names_.empty(),
128  "Reactions were not computed during advanceState, but were requested.");
129 
130  SLIC_ERROR_IF(
131  reaction_index >= reaction_states_.size(),
132  "Reaction reactions not correctly allocated yet, cannot get reaction until after initializationStep is called.");
133 
134  return *reaction_states_[reaction_index].get();
135 }
136 
137 FiniteElementDual DifferentiablePhysics::loadCheckpointedDual(const std::string& reaction_name, int cycle)
138 {
139  SLIC_ERROR_IF(
140  cycle != cycle_,
141  axom::fmt::format("Due to checkpointing restrictions in smith::DifferentiablePhysics, cannot ask for "
142  "an arbitrary checkpointed reaction cycle, asking for cycle {}, but physics is at cycle {}",
143  cycle, cycle_));
144  return dual(reaction_name);
145 }
146 
147 FiniteElementState DifferentiablePhysics::loadCheckpointedState(const std::string& state_name, int cycle)
148 {
149  SLIC_ERROR_IF(cycle != cycle_,
150  std::format("Due to checkpointing restrictions in smith::Mechanics, cannot ask for an arbitrary "
151  "checkpointed cycle, asking for cycle {}, but physics is at cycle {}",
152  cycle, cycle_));
153  return state(state_name);
154 }
155 
156 const FiniteElementState& DifferentiablePhysics::shapeDisplacement() const { return *field_shape_displacement_->get(); }
157 
158 const FiniteElementState& DifferentiablePhysics::parameter(std::size_t parameter_index) const
159 {
160  SLIC_ERROR_IF(parameter_index >= field_params_.size(),
161  std::format("Parameter index {} requested, but only {} parameters exist in physics module {}.",
162  parameter_index, field_params_.size(), name_));
163  return *field_params_[parameter_index].get();
164 }
165 
166 const FiniteElementState& DifferentiablePhysics::parameter(const std::string& parameter_name) const
167 {
168  SLIC_ERROR_IF(
169  param_name_to_field_index_.find(parameter_name) == param_name_to_field_index_.end(),
170  std::format("Could not find parameter named {0} in mesh with tag \"{1}\" to get", parameter_name, mesh_->tag()));
171  size_t param_index = param_name_to_field_index_.at(parameter_name);
172  return parameter(param_index);
173 }
174 
175 void DifferentiablePhysics::setParameter(const size_t parameter_index, const FiniteElementState& parameter_state)
176 {
177  SLIC_ERROR_IF(parameter_index >= field_params_.size(),
178  std::format("Parameter '{}' requested when only '{}' parameters exist in physics module '{}'",
179  parameter_index, field_params_.size(), name_));
180  *field_params_[parameter_index].get() = parameter_state;
181 }
182 
184 {
185  *field_shape_displacement_->get() = shape_displacement;
186 }
187 
188 void DifferentiablePhysics::setState([[maybe_unused]] const std::string& field_name,
189  [[maybe_unused]] const FiniteElementState& s)
190 {
191  SLIC_ERROR_IF(state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
192  std::format("Could not find field named {0} in mesh with tag {1} to set", field_name, mesh_->tag()));
193  size_t state_index = state_name_to_field_index_.at(field_name);
194  *field_states_[state_index].get() = s;
195  *initial_field_states_[state_index].get() = s;
196 }
197 
199  std::unordered_map<std::string, const smith::FiniteElementDual&> string_to_reaction)
200 {
201  for (auto string_reaction_pair : string_to_reaction) {
202  std::string field_name = string_reaction_pair.first;
203  const smith::FiniteElementDual& reaction = string_reaction_pair.second;
204  SLIC_ERROR_IF(
205  state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
206  axom::fmt::format("Could not find reaction named {0} in mesh with tag {1}", field_name, mesh_->tag()));
207  size_t state_index = state_name_to_field_index_.at(field_name);
208  *field_states_[state_index].get_dual() += reaction;
209  }
210 }
211 
213  std::unordered_map<std::string, const smith::FiniteElementState&> string_to_bc)
214 {
215  for (auto string_bc_pair : string_to_bc) {
216  std::string reaction_name = string_bc_pair.first;
217  const smith::FiniteElementState& reaction_adjoint_load = string_bc_pair.second;
218  SLIC_ERROR_IF(
219  reaction_name_to_reaction_index_.find(reaction_name) == reaction_name_to_reaction_index_.end(),
220  axom::fmt::format("When calling setDualAdjointBcs, could not find reaction named {0} in mesh with tag {1}",
221  reaction_name, mesh_->tag()));
222  size_t reaction_index = reaction_name_to_reaction_index_.at(reaction_name);
223  *reaction_states_[reaction_index].get_dual() += reaction_adjoint_load;
224  }
225 }
226 
227 const FiniteElementState& DifferentiablePhysics::adjoint([[maybe_unused]] const std::string& adjoint_name) const
228 {
229  // MRT, not implemented
230  SLIC_ERROR("What is the use case for asking for the adjoint solution field directly?");
231  return *adjoints_[0];
232 }
233 
235 {
236  if (cycle_ == 0) {
237  field_states_ = initial_field_states_;
238  milestones_.push_back(make_milestone(field_states_, reaction_states_).step());
239  }
240 
241  cycle_prev_ = cycle_;
242  time_prev_ = time_;
243  dt_prev_ = dt;
244 
245  TimeInfo time_info(time_, dt, static_cast<size_t>(cycle_));
246  auto [states, reactions] =
247  advancer_->advanceState(time_info, *field_shape_displacement_, field_states_, field_params_);
248  field_states_ = states;
249  reaction_states_ = reactions;
250 
251  cycle_++;
252  time_ += dt;
253  milestones_.push_back(make_milestone(field_states_, reaction_states_).step());
254 }
255 
257 {
258  --cycle_;
259  const gretl::Int milestone = milestones_[static_cast<size_t>(cycle_)];
260 
261  field_shape_displacement_->clear_dual();
262  for (auto& p : field_params_) {
263  p.clear_dual();
264  }
265 
266  gretl::Int current_step = checkpointer_->currentStep_;
267  while (milestone != current_step) {
268  checkpointer_->reverse_state();
269  current_step = checkpointer_->currentStep_;
270  }
271 
272  gretl::UpstreamStates upstreams(*checkpointer_, checkpointer_->upstreamSteps_[milestone]);
273 
274  const size_t expected_upstreams = field_states_.size() + reaction_states_.size();
275  SLIC_ERROR_IF(expected_upstreams != upstreams.size(), "field/reaction states and upstream sizes do not match.");
276  // recreate the upstream field states with upstream step, field, and dual values.
277  for (size_t s = 0; s < field_states_.size(); ++s) {
278  field_states_[s].reset_step(upstreams[s].step_);
279  field_states_[s].set(upstreams[s].get<FEFieldPtr>());
280  field_states_[s].set_dual(upstreams[s].get_dual<FEDualPtr, FEFieldPtr>());
281  }
282  for (size_t r = 0; r < reaction_states_.size(); ++r) {
283  const size_t upstream_index = field_states_.size() + r;
284  reaction_states_[r].reset_step(upstreams[upstream_index].step_);
285  reaction_states_[r].set(upstreams[upstream_index].get<FEDualPtr>());
286  reaction_states_[r].set_dual(upstreams[upstream_index].get_dual<FEFieldPtr, FEDualPtr>());
287  }
288 }
289 
291 {
292  return *field_params_[parameter_index].get_dual();
293 }
294 
296 {
297  return *field_shape_displacement_->get_dual();
298 }
299 
300 const std::unordered_map<std::string, const smith::FiniteElementDual&>
302 {
303  std::unordered_map<std::string, const smith::FiniteElementDual&> map;
304  for (auto& name : stateNames()) {
305  auto state_index = state_name_to_field_index_.at(name);
306  map.insert({name, *initial_field_states_[state_index].get_dual()});
307  }
308  return map;
309 }
310 
312 {
313  std::vector<FieldState> fields;
314  fields.insert(fields.end(), field_states_.begin(), field_states_.end());
315  fields.insert(fields.end(), field_params_.begin(), field_params_.end());
316  return fields;
317 }
318 
319 FieldState DifferentiablePhysics::getShapeDispFieldState() const { return *field_shape_displacement_; }
320 
321 void DifferentiablePhysics::initializeReactionStates()
322 {
323  reaction_states_.clear();
324  reaction_states_.reserve(reaction_infos_.size());
325  for (const auto& reaction_info : reaction_infos_) {
326  auto reaction = std::make_shared<FiniteElementDual>(*reaction_info.space, reaction_info.name);
327  reaction_states_.push_back(createReactionState(*checkpointer_, reaction));
328  }
329 }
330 
331 } // namespace smith
This is the abstract base class for a generic forward solver.
std::string name_
Name of the physics module.
std::shared_ptr< smith::Mesh > mesh_
The primary mesh.
int cycle_
Current cycle (forward pass time iteration count)
std::string name() const
Return the name of the physics.
virtual double time() const
Get the current forward-solution time.
double time_
Current time for the forward pass.
virtual int cycle() const
Get the current forward-solution cycle iteration number.
std::vector< const smith::FiniteElementState * > adjoints_
List of finite element adjoint states associated with this physics module.
void resetStates(int cycle=0, double time=0.0) override
Base method to reset physics states to the initial time. This does not reset design parameters or sha...
FiniteElementDual loadCheckpointedDual(const std::string &state_name, int cycle) override
Accessor for getting a single named finite element dual solution from the physics modules at a given ...
void reverseAdjointTimestep() override
Reverse one recorded timestep through the gretl graph.
void completeSetup() override
Complete the setup and allocate the necessary data structures.
std::vector< std::string > parameterNames() const override
Get a vector of the finite element state parameter names.
void setShapeDisplacement(const FiniteElementState &shape_displacement) override
Set the current shape displacement for the underlying mesh.
FieldState getShapeDispFieldState() const
Get the tracked shape displacement field.
const FiniteElementState & adjoint(const std::string &adjoint_name) const override
Accessor for getting named finite element state adjoint solution from the physics modules.
void setDualAdjointBcs(std::unordered_map< std::string, const smith::FiniteElementState & > string_to_bc) override
Set the dual loads (Dirichlet values) for the adjoint reverse timestep solve. This must be called afte...
std::vector< std::string > dualNames() const override
Get a vector of the finite element state dual (reaction) solution names.
FiniteElementDual computeTimestepSensitivity(size_t parameter_index) override
Compute the implicit sensitivity of the quantity of interest used in defining the adjoint load with r...
virtual void resetAdjointStates() override
Base method to reset physics states back to the end of time to start adjoint calculations again....
DifferentiablePhysics(std::shared_ptr< Mesh > mesh, std::shared_ptr< gretl::DataStore > graph, const FieldState &shape_disp, const std::vector< FieldState > &states, const std::vector< FieldState > &params, std::shared_ptr< StateAdvancer > advancer, std::string physics_name, const std::vector< ReactionInfo > &reaction_infos={})
Construct a differentiable physics wrapper around a state advancer and its tracked fields.
void setAdjointLoad(std::unordered_map< std::string, const smith::FiniteElementDual & > string_to_dual) override
Set the loads for the adjoint reverse timestep solve.
const FiniteElementDual & computeTimestepShapeSensitivity() override
Compute the implicit sensitivity of the quantity of interest used in defining the adjoint load with r...
const FiniteElementState & parameter(std::size_t parameter_index) const override
Accessor for getting indexed finite element state parameter fields from the physics modules.
void setState(const std::string &state_name, const FiniteElementState &s) override
Set the primal solution field values of the underlying physics solver.
const FiniteElementState & shapeDisplacement() const override
Accessor for getting the shape displacement field from the physics modules.
FiniteElementState loadCheckpointedState(const std::string &state_name, int cycle) override
Return a state for a stored checkpoint cycle.
std::vector< FieldState > getFieldStatesAndParamStates() const
Get the tracked state fields followed by the tracked parameter fields.
virtual void advanceTimestep(double dt) override
Advance the state variables according to the chosen time integrator.
const FiniteElementState & state(const std::string &state_name) const override
Accessor for getting named finite element state primal solution from the physics modules.
std::vector< std::string > stateNames() const override
Get a vector of the finite element state primal solution names.
const FiniteElementDual & dual(const std::string &dual_name) const override
Accessor for getting named finite element state dual (reaction) solution from the physics modules.
const std::unordered_map< std::string, const smith::FiniteElementDual & > computeInitialConditionSensitivity() const override
Compute the implicit sensitivity of the quantity of interest with respect to the initial condition fi...
void setParameter(const size_t parameter_index, const FiniteElementState &parameter_state) override
Deep copy a parameter field into the internally-owned parameter used for simulations.
Class for encapsulating the dual vector space of a finite element space (i.e. the space of linear for...
Class for encapsulating the critical MFEM components of a primal finite element field.
Defines a BasePhysics implementation backed by FieldState objects and a gretl computational graph.
Smith mesh class which assists in constructing the appropriate parallel mfem meshes and registering a...
Accelerator functionality.
Definition: smith.cpp:36
gretl::State< FEFieldPtr, FEDualPtr > FieldState
typedef
Definition: field_state.hpp:22
gretl::State< int > make_milestone(const std::vector< FieldState > &states, const std::vector< ReactionState > &reactions)
gretl-function to create a dummy-state which records all states and params of interest to the mechani...
ReactionState createReactionState(gretl::DataStore &dataStore, const smith::FEDualPtr &s)
Initialize a ReactionState on the gretl::DataStore with values from s.
Definition: field_state.hpp:77
mfem::ParFiniteElementSpace & space(FieldState field)
Get the space from the primal field of a field state.
Reaction class which is a named combination of a weak form and a set of Dirichlet-constrained nodes.
Interface and implementations for advancing from one step to the next. Typically these are time integ...
struct storing time and timestep information
Definition: common.hpp:18
Specifies interface for evaluating weak form residuals and their gradients.