Smith  0.1
Smith is an implicit thermal structural mechanics simulation code.
multiphysics_time_integrator.cpp
1 // Copyright (c) Lawrence Livermore National Security, LLC and
2 // other Smith Project Developers. See the top-level LICENSE file for
3 // details.
4 //
5 // SPDX-License-Identifier: (BSD-3-Clause)
6 
7 #include "smith/differentiable_numerics/multiphysics_time_integrator.hpp"
10 #include "smith/differentiable_numerics/system_solver.hpp"
13 
14 #include <algorithm>
15 #include <stdexcept>
16 
17 namespace smith {
18 
19 MultiphysicsTimeIntegrator::MultiphysicsTimeIntegrator(std::shared_ptr<SystemBase> system,
20  std::vector<std::shared_ptr<SystemBase>> cycle_zero_systems,
21  std::vector<std::shared_ptr<SystemBase>> post_solve_systems)
22  : system_(system),
23  cycle_zero_systems_(std::move(cycle_zero_systems)),
24  post_solve_systems_(std::move(post_solve_systems))
25 {
26  for (size_t i = 0; i < system_->weak_forms.size(); ++i) {
27  const std::string wf_name = system_->weak_forms[i]->name();
28  const std::string reaction_name = system_->field_store->getWeakFormReaction(wf_name);
29  main_unknown_name_to_local_idx_[reaction_name] = i;
30  }
31 }
32 
33 void MultiphysicsTimeIntegrator::addPostSolveSystem(std::shared_ptr<SystemBase> system)
34 {
35  post_solve_systems_.push_back(std::move(system));
36 }
37 
// Advance the multiphysics state by one time step.
//
// Steps, in order:
//   1. Copy `states` and push shape displacement, states and parameters into the main system's
//      FieldStore.
//   2. At cycle 0, if cycle-zero systems exist and any time integration rule requests an initial
//      acceleration solve, solve each cycle-zero system and store its results.
//   3. Solve the main system.
//   4. Compute reactions from the pre-step states with the newly solved primary unknowns
//      substituted in.
//   5. Write the primary unknowns into the FieldStore, then solve each post-solve system.
//   6. Apply the time integration rules to produce the new states, write them back to the
//      FieldStore and return them together with the reactions.
//
// Although this method is declared const, it writes to the FieldStore reached through system_.
//
// Parameters:
//   time_info   time, dt and cycle for this step
//   shape_disp  shape displacement field, stored under the FieldStore's shape-displacement name
//   states      input states; entry i is stored at FieldStore index i
//   params      parameter fields; must have the same size as getParameterFields()
//
// Returns the pair {new states, reactions}.
38 std::pair<std::vector<FieldState>, std::vector<ReactionState>> MultiphysicsTimeIntegrator::advanceState(
39  const TimeInfo& time_info, const FieldState& shape_disp, const std::vector<FieldState>& states,
40  const std::vector<FieldState>& params) const
41 {
42  std::vector<FieldState> current_states = states;
43 
44  // Sync FieldStore with (possibly updated) states and params so they are current for solve
45  system_->field_store->setField(system_->field_store->getShapeDisp().get()->name(), shape_disp);
46 
47  for (size_t i = 0; i < current_states.size(); ++i) {
48  system_->field_store->setField(i, current_states[i]);
49  }
50  // Update the parameter fields as well; params[i] is matched to getParameterFields()[i] by position.
51  SLIC_ERROR_ROOT_IF(params.size() != system_->field_store->getParameterFields().size(),
52  "Parameter size mismatch in advanceState");
53  for (size_t i = 0; i < params.size(); ++i) {
54  system_->field_store->setField(system_->field_store->getParameterFields()[i].get()->name(), params[i]);
55  }
56 
57  // Handle initial acceleration solve at cycle 0
    // NOTE(review): begin() and end() below come from two separate getTimeIntegrationRules() calls.
    // That is only valid if the accessor returns a reference to one container rather than a copy
    // -- confirm against the FieldStore declaration.
58  const bool requires_cycle_zero_solve =
59  std::any_of(system_->field_store->getTimeIntegrationRules().begin(),
60  system_->field_store->getTimeIntegrationRules().end(), [](const auto& rule_and_mapping) {
61  return rule_and_mapping.first && rule_and_mapping.first->requiresInitialAccelerationSolve();
62  });
63 
64  if (time_info.cycle() == 0 && !cycle_zero_systems_.empty() && requires_cycle_zero_solve) {
65  for (const auto& cz_sys : cycle_zero_systems_) {
    // The cycle-zero solve uses a TimeInfo whose time is one dt before the current time.
    // NOTE(review): the argument list of this constructor call is cut off in this listing (listing
    // line 67 is absent); check the original file for the remaining argument(s).
66  TimeInfo cycle_zero_time_info(time_info.time() - time_info.dt(), time_info.dt(), time_info.cycle(),
68  auto cycle_zero_unknowns = cz_sys->solve(cycle_zero_time_info);
69 
70  SLIC_ERROR_ROOT_IF(cycle_zero_unknowns.size() != cz_sys->weak_forms.size(),
71  "Cycle zero system result count does not match number of cycle-zero weak forms");
72  SLIC_ERROR_ROOT_IF(!cz_sys->solve_result_field_names.empty() &&
73  cz_sys->solve_result_field_names.size() != cz_sys->weak_forms.size(),
74  "Cycle zero solve_result_field_names size does not match number of weak forms");
75  for (size_t i = 0; i < cz_sys->weak_forms.size(); ++i) {
    // Target field: the explicit result name when one is given, otherwise the reaction name
    // registered for this weak form in the main system's FieldStore.
76  const std::string result_field_name =
77  cz_sys->solve_result_field_names.empty()
78  ? system_->field_store->getWeakFormReaction(cz_sys->weak_forms[i]->name())
79  : cz_sys->solve_result_field_names[i];
80  size_t result_field_state_idx = system_->field_store->getFieldIndex(result_field_name);
81  current_states[result_field_state_idx] = cycle_zero_unknowns[i];
82  system_->field_store->setField(result_field_state_idx, cycle_zero_unknowns[i]);
83  }
84  }
85  }
86 
    // NOTE(review): unlike the cycle-zero path, the size of primary_unknowns is not checked against
    // the number of main-system weak forms before it is indexed below.
87  std::vector<FieldState> primary_unknowns = system_->solve(time_info);
88 
89  // main_unknown_name_to_local_idx_ (filled in the constructor) maps the main system's unknown
90  // names to their position in primary_unknowns. Entries in the shared FieldStore's time
91  // integration rules that belong to post-solve subsystems (e.g. stress projection) are NOT
92  // present in that map and are skipped by the loops below that walk getTimeIntegrationRules().
93 
94  // Create states for reaction computation: newly solved primary unknowns + current states
95  std::vector<FieldState> states_for_reactions = current_states;
96  for (const auto& [rule, mapping] : system_->field_store->getTimeIntegrationRules()) {
97  auto it = main_unknown_name_to_local_idx_.find(mapping.primary_name);
98  if (it == main_unknown_name_to_local_idx_.end()) {
99  continue; // rule belongs to a post-solve subsystem, not the main solve
100  }
101  size_t u_idx = system_->field_store->getFieldIndex(mapping.primary_name);
102  FieldState u_new = primary_unknowns[it->second];
103  states_for_reactions[u_idx] = u_new;
104  }
105 
106  // Compute reactions using newly solved unknowns but BEFORE time integration state updates
107  std::vector<ReactionState> reactions = system_->computeReactions(time_info, states_for_reactions);
108 
109  // Sync field_store with newly solved primary unknowns so post-solve systems (e.g. stress
110  // projection) read the current displacement rather than the pre-solve snapshot.
111  for (const auto& [rule, mapping] : system_->field_store->getTimeIntegrationRules()) {
112  auto it = main_unknown_name_to_local_idx_.find(mapping.primary_name);
113  if (it == main_unknown_name_to_local_idx_.end()) {
114  continue;
115  }
116  size_t u_idx = system_->field_store->getFieldIndex(mapping.primary_name);
117  system_->field_store->setField(u_idx, primary_unknowns[it->second]);
118  }
119 
120  // Solve post-solve systems (e.g. stress projection for output) and sync their results back
121  // into the shared field_store so getAllFields() returns the updated values for new_states.
    // NOTE(review): results are written through ps->field_store, while new_states is later read
    // from system_->field_store; this presumes both point at the same FieldStore -- confirm.
122  for (const auto& ps : post_solve_systems_) {
123  auto ps_unknowns = ps->solve(time_info);
124  for (size_t i = 0; i < ps->weak_forms.size(); ++i) {
125  const std::string reaction_name = ps->field_store->getWeakFormReaction(ps->weak_forms[i]->name());
126  size_t u_idx = ps->field_store->getFieldIndex(reaction_name);
127  ps->field_store->setField(u_idx, ps_unknowns[i]);
128  }
129  }
130 
131  // Now do time integration to compute corrected velocities/accelerations and update all states.
132  // Seed new_states from field_store, which already reflects post-solve subsystem updates and
133  // the freshly synced primary unknowns.
134  std::vector<FieldState> new_states = system_->field_store->getAllFields();
135 
136  for (const auto& [rule, mapping] : system_->field_store->getTimeIntegrationRules()) {
137  auto it = main_unknown_name_to_local_idx_.find(mapping.primary_name);
138  if (it == main_unknown_name_to_local_idx_.end()) {
139  continue; // rule belongs to a post-solve subsystem, not the main solve
140  }
141  size_t u_idx = system_->field_store->getFieldIndex(mapping.primary_name);
142  FieldState u_new = primary_unknowns[it->second];
143  new_states[u_idx] = u_new;
144 
    // Assemble the rule inputs from the pre-step values in current_states; how many are passed
    // depends on rule->num_args() and on which derivative field names the mapping provides.
145  std::vector<FieldState> rule_inputs;
146  rule_inputs.push_back(u_new); // u_{n+1}
147  if (rule->num_args() >= 2) {
148  rule_inputs.push_back(current_states[u_idx]); // u_n
149  }
150 
151  if (rule->num_args() >= 3 && !mapping.dot_name.empty()) {
152  size_t v_idx = system_->field_store->getFieldIndex(mapping.dot_name);
153  rule_inputs.push_back(current_states[v_idx]);
154  }
155 
156  if (rule->num_args() >= 4 && !mapping.ddot_name.empty()) {
157  size_t a_idx = system_->field_store->getFieldIndex(mapping.ddot_name);
158  rule_inputs.push_back(current_states[a_idx]);
159  }
160 
161  if (!mapping.dot_name.empty()) {
162  size_t v_idx = system_->field_store->getFieldIndex(mapping.dot_name);
163  new_states[v_idx] = rule->corrected_dot(time_info, rule_inputs);
164  }
165 
166  if (!mapping.ddot_name.empty()) {
167  size_t a_idx = system_->field_store->getFieldIndex(mapping.ddot_name);
168  new_states[a_idx] = rule->corrected_ddot(time_info, rule_inputs);
169  }
170 
171  if (!mapping.history_name.empty()) {
172  size_t hist_idx = system_->field_store->getFieldIndex(mapping.history_name);
173  new_states[hist_idx] = u_new;
174  }
175  }
176 
177  // Copy solve-state → history for post-solve fields when a public history field exists.
178  // The main loop skipped these rules; their primary fields are already correct in new_states
179  // (seeded from field_store->getAllFields() above), so only the history field needs updating.
180  for (const auto& [rule, mapping] : system_->field_store->getTimeIntegrationRules()) {
181  if (main_unknown_name_to_local_idx_.count(mapping.primary_name)) {
182  continue; // already handled by main time integration loop above
183  }
184  if (!mapping.history_name.empty()) {
185  size_t primary_idx = system_->field_store->getFieldIndex(mapping.primary_name);
186  size_t hist_idx = system_->field_store->getFieldIndex(mapping.history_name);
187  new_states[hist_idx] = new_states[primary_idx];
188  }
189  }
190 
    // Write the fully updated states back so the FieldStore matches what is returned.
191  for (size_t i = 0; i < new_states.size(); ++i) {
192  system_->field_store->setField(i, new_states[i]);
193  }
194 
195  return {new_states, reactions};
196 }
197 
198 } // namespace smith
MultiphysicsTimeIntegrator(std::shared_ptr< SystemBase > system, std::vector< std::shared_ptr< SystemBase >> cycle_zero_systems={}, std::vector< std::shared_ptr< SystemBase >> post_solve_systems={})
Construct a multiphysics advancer around main and auxiliary systems.
std::pair< std::vector< FieldState >, std::vector< ReactionState > > advanceState(const TimeInfo &time_info, const FieldState &shape_disp, const std::vector< FieldState > &states, const std::vector< FieldState > &params) const override
Advance the multiphysics state by one time step.
void addPostSolveSystem(std::shared_ptr< SystemBase > system)
Register a system to be solved after the main solve and reaction computation.
Contains DirichletBoundaryConditions class for interaction with the differentiable solve interfaces.
Accelerator functionality.
Definition: smith.cpp:36
gretl::State< FEFieldPtr, FEDualPtr > FieldState
typedef
Definition: field_state.hpp:22
This file contains nonlinear block solver interfaces and helpers.
Methods for solving systems of equations as given by WeakForms. Tracks these operations on the gretl ...
Reaction class which is a named combination of a weak form and a set of Dirichlet-constrained nodes.
struct storing time and timestep information
Definition: common.hpp:18
double dt() const
accessor for dt
Definition: common.hpp:36
size_t cycle() const
accessor for cycle
Definition: common.hpp:39
@ CycleZero
Initialization or cycle zero step.
double time() const
accessor for the current time
Definition: common.hpp:33