Smith  0.1
Smith is an implicit thermal structural mechanics simulation code.
differentiable_physics.cpp
1 // Copyright (c) Lawrence Livermore National Security, LLC and
2 // other Smith Project Developers. See the top-level LICENSE file for
3 // details.
4 //
5 // SPDX-License-Identifier: (BSD-3-Clause)
6 
7 #include "gretl/data_store.hpp"
10 #include "smith/physics/mesh.hpp"
13 #include "gretl/upstream_state.hpp"
14 
15 namespace smith {
16 
20 gretl::State<int> make_milestone(const std::vector<FieldState>& states, const std::vector<ReactionState>& reactions)
21 {
22  std::vector<gretl::StateBase> base_states;
23  for (const auto& s : states) {
24  base_states.push_back(s);
25  }
26  for (const auto& r : reactions) {
27  base_states.push_back(r);
28  }
29 
30  auto milestone = states[0].create_state<int, int>(base_states);
31 
32  milestone.set_eval(
33  []([[maybe_unused]] const gretl::UpstreamStates& inputs, gretl::DownstreamState& output) { output.set<int>(0); });
34  milestone.set_vjp(
35  []([[maybe_unused]] gretl::UpstreamStates& inputs, [[maybe_unused]] const gretl::DownstreamState& output) {});
36 
37  return milestone.finalize();
38 }
39 
40 // mesh, equation, fields, parameters, state advancer, solver
41 DifferentiablePhysics::DifferentiablePhysics(std::shared_ptr<Mesh> mesh, std::shared_ptr<gretl::DataStore> graph,
42  const FieldState& shape_disp, const std::vector<FieldState>& states,
43  const std::vector<FieldState>& params,
44  std::shared_ptr<StateAdvancer> advancer, std::string mech_name,
45  const std::vector<ReactionInfo>& reaction_infos)
46  : BasePhysics(mech_name, mesh, 0, 0.0, false), // the false is checkpoint_to_disk
47  checkpointer_(graph),
48  advancer_(advancer),
49  reaction_infos_(reaction_infos)
50 {
51  SLIC_ERROR_IF(states.size() == 0, "Must have a least 1 state for a mechanics.");
52  field_shape_displacement_ = std::make_unique<FieldState>(shape_disp);
53  for (size_t i = 0; i < states.size(); ++i) {
54  const auto& s = states[i];
55  field_states_.push_back(s);
56  initial_field_states_.push_back(s);
57  state_name_to_field_index_[s.get()->name()] = i;
58  state_names_.push_back(s.get()->name());
59  }
60 
61  for (size_t i = 0; i < params.size(); ++i) {
62  const auto& p = params[i];
63  field_params_.push_back(p);
64  param_name_to_field_index_[p.get()->name()] = i;
65  param_names_.push_back(p.get()->name());
66  }
67 
68  reaction_names_.reserve(reaction_infos_.size());
69  for (size_t i = 0; i < reaction_infos_.size(); ++i) {
70  SLIC_ERROR_IF(
71  reaction_infos_[i].space == nullptr,
72  axom::fmt::format("Dual '{}' in physics module '{}' has null FE space.", reaction_infos_[i].name, name_));
73  reaction_names_.push_back(reaction_infos_[i].name);
74  reaction_name_to_reaction_index_[reaction_infos_[i].name] = i;
75  }
76 
77  completeSetup();
78 }
79 
81 {
82  SLIC_ERROR_IF(field_states_.empty(), "Empty field state during completeSetup()");
83  initializeReactionStates();
84 }
85 
86 void DifferentiablePhysics::resetStates(int cycle, double time)
87 {
88  for (size_t i = 0; i < initial_field_states_.size(); ++i) {
89  field_states_[i] = initial_field_states_[i];
90  }
91  milestones_.clear();
92  checkpointer_->reset_graph();
93  initializeReactionStates();
94  time_ = time;
95  cycle_ = cycle;
96 }
97 
99 {
100  checkpointer_->finalize_graph();
101  checkpointer_->reset_for_backprop();
102  gretl_assert(checkpointer_->check_validity());
103 }
104 
105 std::vector<std::string> DifferentiablePhysics::stateNames() const { return state_names_; }
106 
107 std::vector<std::string> DifferentiablePhysics::parameterNames() const { return param_names_; }
108 
109 std::vector<std::string> DifferentiablePhysics::dualNames() const { return reaction_names_; }
110 
111 const FiniteElementState& DifferentiablePhysics::state([[maybe_unused]] const std::string& field_name) const
112 {
113  SLIC_ERROR_IF(
114  state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
115  std::format("Could not find field named {0} in mesh with tag \"{1}\" to get", field_name, mesh_->tag()));
116  size_t state_index = state_name_to_field_index_.at(field_name);
117  return *field_states_[state_index].get();
118 }
119 
120 const FiniteElementDual& DifferentiablePhysics::dual(const std::string& reaction_name) const
121 {
122  if (reaction_name_to_reaction_index_.find(reaction_name) == reaction_name_to_reaction_index_.end()) {
123  std::string available;
124  for (auto& n : reaction_names_) available += n + " ";
125  SLIC_ERROR(axom::fmt::format(
126  "Could not find reaction named {0} in mesh with tag \"{1}\" to get. Available reactions (size {2}): {3}",
127  reaction_name, mesh_->tag(), reaction_names_.size(), available));
128  }
129  size_t reaction_index = reaction_name_to_reaction_index_.at(reaction_name);
130 
131  SLIC_ERROR_IF(reaction_states_.empty() && !reaction_names_.empty(),
132  "Reactions were not computed during advanceState, but were requested.");
133 
134  SLIC_ERROR_IF(
135  reaction_index >= reaction_states_.size(),
136  "Reaction reactions not correctly allocated yet, cannot get reaction until after initializationStep is called.");
137 
138  return *reaction_states_[reaction_index].get();
139 }
140 
141 FiniteElementDual DifferentiablePhysics::loadCheckpointedDual(const std::string& reaction_name, int cycle)
142 {
143  SLIC_ERROR_IF(
144  cycle != cycle_,
145  axom::fmt::format("Due to checkpointing restrictions in smith::DifferentiablePhysics, cannot ask for "
146  "an arbitrary checkpointed reaction cycle, asking for cycle {}, but physics is at cycle {}",
147  cycle, cycle_));
148  return dual(reaction_name);
149 }
150 
151 FiniteElementState DifferentiablePhysics::loadCheckpointedState(const std::string& state_name, int cycle)
152 {
153  SLIC_ERROR_IF(cycle != cycle_,
154  std::format("Due to checkpointing restrictions in smith::Mechanics, cannot ask for an arbitrary "
155  "checkpointed cycle, asking for cycle {}, but physics is at cycle {}",
156  cycle, cycle_));
157  return state(state_name);
158 }
159 
160 const FiniteElementState& DifferentiablePhysics::shapeDisplacement() const { return *field_shape_displacement_->get(); }
161 
162 const FiniteElementState& DifferentiablePhysics::parameter(std::size_t parameter_index) const
163 {
164  SLIC_ERROR_IF(parameter_index >= field_params_.size(),
165  std::format("Parameter index {} requested, but only {} parameters exist in physics module {}.",
166  parameter_index, field_params_.size(), name_));
167  return *field_params_[parameter_index].get();
168 }
169 
170 const FiniteElementState& DifferentiablePhysics::parameter(const std::string& parameter_name) const
171 {
172  SLIC_ERROR_IF(
173  param_name_to_field_index_.find(parameter_name) == param_name_to_field_index_.end(),
174  std::format("Could not find parameter named {0} in mesh with tag \"{1}\" to get", parameter_name, mesh_->tag()));
175  size_t param_index = param_name_to_field_index_.at(parameter_name);
176  return parameter(param_index);
177 }
178 
179 void DifferentiablePhysics::setParameter(const size_t parameter_index, const FiniteElementState& parameter_state)
180 {
181  SLIC_ERROR_IF(parameter_index >= field_params_.size(),
182  std::format("Parameter '{}' requested when only '{}' parameters exist in physics module '{}'",
183  parameter_index, field_params_.size(), name_));
184  *field_params_[parameter_index].get() = parameter_state;
185 }
186 
188 {
189  *field_shape_displacement_->get() = shape_displacement;
190 }
191 
192 void DifferentiablePhysics::setState([[maybe_unused]] const std::string& field_name,
193  [[maybe_unused]] const FiniteElementState& s)
194 {
195  SLIC_ERROR_IF(state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
196  std::format("Could not find field named {0} in mesh with tag {1} to set", field_name, mesh_->tag()));
197  size_t state_index = state_name_to_field_index_.at(field_name);
198  *field_states_[state_index].get() = s;
199  *initial_field_states_[state_index].get() = s;
200 }
201 
203  std::unordered_map<std::string, const smith::FiniteElementDual&> string_to_reaction)
204 {
205  for (auto string_reaction_pair : string_to_reaction) {
206  std::string field_name = string_reaction_pair.first;
207  const smith::FiniteElementDual& reaction = string_reaction_pair.second;
208  SLIC_ERROR_IF(
209  state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
210  axom::fmt::format("Could not find reaction named {0} in mesh with tag {1}", field_name, mesh_->tag()));
211  size_t state_index = state_name_to_field_index_.at(field_name);
212  *field_states_[state_index].get_dual() += reaction;
213  }
214 }
215 
217  std::unordered_map<std::string, const smith::FiniteElementState&> string_to_bc)
218 {
219  for (auto string_bc_pair : string_to_bc) {
220  std::string reaction_name = string_bc_pair.first;
221  const smith::FiniteElementState& reaction_adjoint_load = string_bc_pair.second;
222  SLIC_ERROR_IF(
223  reaction_name_to_reaction_index_.find(reaction_name) == reaction_name_to_reaction_index_.end(),
224  axom::fmt::format("When calling setDualAdjointBcs, could not find reaction named {0} in mesh with tag {1}",
225  reaction_name, mesh_->tag()));
226  size_t reaction_index = reaction_name_to_reaction_index_.at(reaction_name);
227  *reaction_states_[reaction_index].get_dual() += reaction_adjoint_load;
228  }
229 }
230 
231 const FiniteElementState& DifferentiablePhysics::adjoint([[maybe_unused]] const std::string& adjoint_name) const
232 {
233  // MRT, not implemented
234  SLIC_ERROR("What is the use case for asking for the adjoint solution field directly?");
235  return *adjoints_[0];
236 }
237 
239 {
240  if (cycle_ == 0) {
241  field_states_ = initial_field_states_;
242  milestones_.push_back(make_milestone(field_states_, reaction_states_).step());
243  }
244 
245  cycle_prev_ = cycle_;
246  time_prev_ = time_;
247  dt_prev_ = dt;
248 
249  TimeInfo time_info(time_, dt, static_cast<size_t>(cycle_));
250  auto [states, reactions] =
251  advancer_->advanceState(time_info, *field_shape_displacement_, field_states_, field_params_);
252  field_states_ = states;
253  reaction_states_ = reactions;
254 
255  cycle_++;
256  time_ += dt;
257  milestones_.push_back(make_milestone(field_states_, reaction_states_).step());
258 }
259 
261 {
262  --cycle_;
263  const gretl::Int milestone = milestones_[static_cast<size_t>(cycle_)];
264 
265  field_shape_displacement_->clear_dual();
266  for (auto& p : field_params_) {
267  p.clear_dual();
268  }
269 
270  gretl::Int current_step = checkpointer_->currentStep_;
271  while (milestone != current_step) {
272  checkpointer_->reverse_state();
273  current_step = checkpointer_->currentStep_;
274  }
275 
276  gretl::UpstreamStates upstreams(*checkpointer_, checkpointer_->upstreamSteps_[milestone]);
277 
278  const size_t expected_upstreams = field_states_.size() + reaction_states_.size();
279  SLIC_ERROR_IF(expected_upstreams != upstreams.size(), "field/reaction states and upstream sizes do not match.");
280  // recreate the upstream field states with upstream step, field, and dual values.
281  for (size_t s = 0; s < field_states_.size(); ++s) {
282  field_states_[s].reset_step(upstreams[s].step_);
283  field_states_[s].set(upstreams[s].get<FEFieldPtr>());
284  field_states_[s].set_dual(upstreams[s].get_dual<FEDualPtr, FEFieldPtr>());
285  }
286  for (size_t r = 0; r < reaction_states_.size(); ++r) {
287  const size_t upstream_index = field_states_.size() + r;
288  reaction_states_[r].reset_step(upstreams[upstream_index].step_);
289  reaction_states_[r].set(upstreams[upstream_index].get<FEDualPtr>());
290  reaction_states_[r].set_dual(upstreams[upstream_index].get_dual<FEFieldPtr, FEDualPtr>());
291  }
292 }
293 
295 {
296  return *field_params_[parameter_index].get_dual();
297 }
298 
300 {
301  return *field_shape_displacement_->get_dual();
302 }
303 
304 const std::unordered_map<std::string, const smith::FiniteElementDual&>
306 {
307  std::unordered_map<std::string, const smith::FiniteElementDual&> map;
308  for (auto& name : stateNames()) {
309  auto state_index = state_name_to_field_index_.at(name);
310  map.insert({name, *initial_field_states_[state_index].get_dual()});
311  }
312  return map;
313 }
314 
316 {
317  std::vector<FieldState> fields;
318  fields.insert(fields.end(), field_states_.begin(), field_states_.end());
319  fields.insert(fields.end(), field_params_.begin(), field_params_.end());
320  return fields;
321 }
322 
323 FieldState DifferentiablePhysics::getInitialFieldState(const std::string& state_name) const
324 {
325  SLIC_ERROR_IF(state_name_to_field_index_.find(state_name) == state_name_to_field_index_.end(),
326  std::format("Could not find initial field named {0}", state_name));
327  return initial_field_states_[state_name_to_field_index_.at(state_name)];
328 }
329 
330 FieldState DifferentiablePhysics::getFieldState(const std::string& state_name) const
331 {
332  SLIC_ERROR_IF(state_name_to_field_index_.find(state_name) == state_name_to_field_index_.end(),
333  std::format("Could not find field named {0}", state_name));
334  return field_states_[state_name_to_field_index_.at(state_name)];
335 }
336 
337 FieldState DifferentiablePhysics::getFieldParam(const std::string& param_name) const
338 {
339  SLIC_ERROR_IF(param_name_to_field_index_.find(param_name) == param_name_to_field_index_.end(),
340  std::format("Could not find parameter named {0}", param_name));
341  return field_params_[param_name_to_field_index_.at(param_name)];
342 }
343 
344 FieldState DifferentiablePhysics::getShapeDispFieldState() const { return *field_shape_displacement_; }
345 
346 void DifferentiablePhysics::initializeReactionStates()
347 {
348  reaction_states_.clear();
349  reaction_states_.reserve(reaction_infos_.size());
350  for (const auto& reaction_info : reaction_infos_) {
351  auto reaction = std::make_shared<FiniteElementDual>(*reaction_info.space, reaction_info.name);
352  reaction_states_.push_back(createReactionState(*checkpointer_, reaction));
353  }
354 }
355 
356 } // namespace smith
This is the abstract base class for a generic forward solver.
std::string name_
Name of the physics module.
std::shared_ptr< smith::Mesh > mesh_
The primary mesh.
int cycle_
Current cycle (forward pass time iteration count)
std::string name() const
Return the name of the physics.
virtual double time() const
Get the current forward-solution time.
double time_
Current time for the forward pass.
virtual int cycle() const
Get the current forward-solution cycle iteration number.
std::vector< const smith::FiniteElementState * > adjoints_
List of finite element adjoint states associated with this physics module.
void resetStates(int cycle=0, double time=0.0) override
Base method to reset physics states to the initial time. This does not reset design parameters or sha...
FiniteElementDual loadCheckpointedDual(const std::string &state_name, int cycle) override
Accessor for getting a single named finite element dual solution from the physics modules at a given ...
FieldState getFieldParam(const std::string &param_name) const
Get a tracked parameter field by name.
void reverseAdjointTimestep() override
Reverse one recorded timestep through the gretl graph.
void completeSetup() override
Complete the setup and allocate the necessary data structures.
std::vector< std::string > parameterNames() const override
Get a vector of the finite element state parameter names.
FieldState getFieldState(const std::string &state_name) const
Get a tracked current state field by name.
void setShapeDisplacement(const FiniteElementState &shape_displacement) override
Set the current shape displacement for the underlying mesh.
FieldState getShapeDispFieldState() const
Get the tracked shape displacement field.
const FiniteElementState & adjoint(const std::string &adjoint_name) const override
Accessor for getting named finite element state adjoint solution from the physics modules.
void setDualAdjointBcs(std::unordered_map< std::string, const smith::FiniteElementState & > string_to_bc) override
Set the dual loads (Dirichlet values) for the adjoint reverse timestep solve. This must be called after...
std::vector< std::string > dualNames() const override
Get a vector of the finite element state dual (reaction) solution names.
FiniteElementDual computeTimestepSensitivity(size_t parameter_index) override
Compute the implicit sensitivity of the quantity of interest used in defining the adjoint load with r...
virtual void resetAdjointStates() override
Base method to reset physics states back to the end of time to start adjoint calculations again....
DifferentiablePhysics(std::shared_ptr< Mesh > mesh, std::shared_ptr< gretl::DataStore > graph, const FieldState &shape_disp, const std::vector< FieldState > &states, const std::vector< FieldState > &params, std::shared_ptr< StateAdvancer > advancer, std::string physics_name, const std::vector< ReactionInfo > &reaction_infos={})
Construct a differentiable physics wrapper around a state advancer and its tracked fields.
void setAdjointLoad(std::unordered_map< std::string, const smith::FiniteElementDual & > string_to_dual) override
Set the loads for the adjoint reverse timestep solve.
const FiniteElementDual & computeTimestepShapeSensitivity() override
Compute the implicit sensitivity of the quantity of interest used in defining the adjoint load with r...
const FiniteElementState & parameter(std::size_t parameter_index) const override
Accessor for getting indexed finite element state parameter fields from the physics modules.
void setState(const std::string &state_name, const FiniteElementState &s) override
Set the primal solution field values of the underlying physics solver.
FieldState getInitialFieldState(const std::string &state_name) const
Get a tracked initial state field by name.
const FiniteElementState & shapeDisplacement() const override
Accessor for getting the shape displacement field from the physics modules.
FiniteElementState loadCheckpointedState(const std::string &state_name, int cycle) override
Return a state for a stored checkpoint cycle.
std::vector< FieldState > getFieldStatesAndParamStates() const
Get the tracked state fields followed by the tracked parameter fields.
virtual void advanceTimestep(double dt) override
Advance the state variables according to the chosen time integrator.
const FiniteElementState & state(const std::string &state_name) const override
Accessor for getting named finite element state primal solution from the physics modules.
std::vector< std::string > stateNames() const override
Get a vector of the finite element state primal solution names.
const FiniteElementDual & dual(const std::string &dual_name) const override
Accessor for getting named finite element state dual (reaction) solution from the physics modules.
const std::unordered_map< std::string, const smith::FiniteElementDual & > computeInitialConditionSensitivity() const override
Compute the implicit sensitivity of the quantity of interest with respect to the initial condition fi...
void setParameter(const size_t parameter_index, const FiniteElementState &parameter_state) override
Deep copy a parameter field into the internally-owned parameter used for simulations.
Class for encapsulating the dual vector space of a finite element space (i.e. the space of linear for...
Class for encapsulating the critical MFEM components of a primal finite element field.
Defines a BasePhysics implementation backed by FieldState objects and a gretl computational graph.
Smith mesh class which assists in constructing the appropriate parallel mfem meshes and registering a...
Accelerator functionality.
Definition: smith.cpp:36
gretl::State< FEFieldPtr, FEDualPtr > FieldState
typedef
Definition: field_state.hpp:22
gretl::State< int > make_milestone(const std::vector< FieldState > &states, const std::vector< ReactionState > &reactions)
gretl-function to create a dummy-state which records all states and params of interest to the mechani...
ReactionState createReactionState(gretl::DataStore &dataStore, const smith::FEDualPtr &s)
initialize on the gretl::DataStore a ReactionState with values from s
Definition: field_state.hpp:77
mfem::ParFiniteElementSpace & space(FieldState field)
Get the space from the primal field of a field states.
Reaction class which is a named combination of a weak form and a set of Dirichlet-constrained nodes.
Interface and implementations for advancing from one step to the next. Typically these are time integ...
struct storing time and timestep information
Definition: common.hpp:18
Specifies interface for evaluating weak form residuals and their gradients.