7 #include "smith/differentiable_numerics/multiphysics_time_integrator.hpp"
10 #include "smith/differentiable_numerics/system_solver.hpp"
// MultiphysicsTimeIntegrator constructor (fragment of a source listing).
// Stores the auxiliary systems solved at cycle zero and after the main solve, then builds
// main_unknown_name_to_local_idx_: for each main weak form, the field name returned by
// getWeakFormReaction(weak form name) maps to that weak form's index in system_->weak_forms.
// advanceState later looks this map up by each time-integration rule's primary field name and
// uses the index to pick the matching entry of the vector returned by system_->solve().
// NOTE(review): this listing omits some original lines (the constructor header and, presumably,
// the system_ member initializer); these comments describe only the visible lines.
20 std::vector<std::shared_ptr<SystemBase>> cycle_zero_systems,
21 std::vector<std::shared_ptr<SystemBase>> post_solve_systems)
23 cycle_zero_systems_(std::move(cycle_zero_systems)),
24 post_solve_systems_(std::move(post_solve_systems))
// Record which local weak-form index produces the unknown for each field name.
26 for (
size_t i = 0; i < system_->weak_forms.size(); ++i) {
27 const std::string wf_name = system_->weak_forms[i]->name();
28 const std::string reaction_name = system_->field_store->getWeakFormReaction(wf_name);
29 main_unknown_name_to_local_idx_[reaction_name] = i;
// addPostSolveSystem body (fragment): appends the system to post_solve_systems_, which
// advanceState solves, in registration order, after the main solve and reaction computation.
35 post_solve_systems_.push_back(std::move(system));
// advanceState (fragment of a source listing; the function name line and several interior lines,
// including braces, are not shown).
// Advances the multiphysics state by one time step:
//   1. loads shape displacement, states and parameters into the main system's field store;
//   2. at cycle 0, optionally runs the cycle-zero systems and stores their results;
//   3. solves the main system;
//   4. computes reactions using the newly solved primary fields;
//   5. writes the primary fields back and runs the post-solve systems;
//   6. applies each time-integration rule's corrected_dot / corrected_ddot and updates history fields;
//   7. writes all resulting fields back to the field store.
// Returns the new field states (ordered like getAllFields()) and the reactions.
39 const TimeInfo& time_info,
const FieldState& shape_disp,
const std::vector<FieldState>& states,
40 const std::vector<FieldState>& params)
const
// Working copy of the incoming states. Cycle-zero results overwrite entries below, and the copy
// then serves as the "previous step" values fed to the time-integration rules.
42 std::vector<FieldState> current_states = states;
45 system_->field_store->setField(system_->field_store->getShapeDisp().get()->name(), shape_disp);
// States are written by position, so states[i] is assumed to correspond to field index i.
// NOTE(review): unlike params below, states.size() is not checked against the field store size.
47 for (
size_t i = 0; i < current_states.size(); ++i) {
48 system_->field_store->setField(i, current_states[i]);
51 SLIC_ERROR_ROOT_IF(params.size() != system_->field_store->getParameterFields().size(),
52 "Parameter size mismatch in advanceState");
53 for (
size_t i = 0; i < params.size(); ++i) {
54 system_->field_store->setField(system_->field_store->getParameterFields()[i].get()->name(), params[i]);
// A cycle-zero solve is needed only if some (non-null) rule asks for an initial acceleration solve.
58 const bool requires_cycle_zero_solve =
59 std::any_of(system_->field_store->getTimeIntegrationRules().begin(),
60 system_->field_store->getTimeIntegrationRules().end(), [](
const auto& rule_and_mapping) {
61 return rule_and_mapping.first && rule_and_mapping.first->requiresInitialAccelerationSolve();
64 if (time_info.
cycle() == 0 && !cycle_zero_systems_.empty() && requires_cycle_zero_solve) {
65 for (
const auto& cz_sys : cycle_zero_systems_) {
// NOTE(review): cycle_zero_time_info is defined on lines omitted from this listing; presumably it
// is built from time_info with TimeInfo::CycleZero. Confirm against the full source.
68 auto cycle_zero_unknowns = cz_sys->solve(cycle_zero_time_info);
// Exactly one result per cycle-zero weak form, and explicit result names (if given) must match that count.
70 SLIC_ERROR_ROOT_IF(cycle_zero_unknowns.size() != cz_sys->weak_forms.size(),
71 "Cycle zero system result count does not match number of cycle-zero weak forms");
72 SLIC_ERROR_ROOT_IF(!cz_sys->solve_result_field_names.empty() &&
73 cz_sys->solve_result_field_names.size() != cz_sys->weak_forms.size(),
74 "Cycle zero solve_result_field_names size does not match number of weak forms");
// Each result goes to the explicitly named field if names were provided, otherwise to the field
// getWeakFormReaction associates with the weak form. It is stored both in the working copy and
// in the field store, so the main solve sees it.
75 for (
size_t i = 0; i < cz_sys->weak_forms.size(); ++i) {
76 const std::string result_field_name =
77 cz_sys->solve_result_field_names.empty()
78 ? system_->field_store->getWeakFormReaction(cz_sys->weak_forms[i]->name())
79 : cz_sys->solve_result_field_names[i];
80 size_t result_field_state_idx = system_->field_store->getFieldIndex(result_field_name);
81 current_states[result_field_state_idx] = cycle_zero_unknowns[i];
82 system_->field_store->setField(result_field_state_idx, cycle_zero_unknowns[i]);
// Main solve. Entries are indexed by local weak-form index (see main_unknown_name_to_local_idx_).
87 std::vector<FieldState> primary_unknowns = system_->solve(time_info);
// Reactions are evaluated with each rule's primary field replaced by the new solution.
// Rules whose primary field is not a main unknown are not handled here; the body of that
// branch is not shown in this listing (presumably it skips the rule).
95 std::vector<FieldState> states_for_reactions = current_states;
96 for (
const auto& [rule, mapping] : system_->field_store->getTimeIntegrationRules()) {
97 auto it = main_unknown_name_to_local_idx_.find(mapping.primary_name);
98 if (it == main_unknown_name_to_local_idx_.end()) {
101 size_t u_idx = system_->field_store->getFieldIndex(mapping.primary_name);
102 FieldState u_new = primary_unknowns[it->second];
103 states_for_reactions[u_idx] = u_new;
107 std::vector<ReactionState> reactions = system_->computeReactions(time_info, states_for_reactions);
// Write the new primary fields into the field store before the post-solve systems run,
// presumably so those systems see the updated primaries.
111 for (
const auto& [rule, mapping] : system_->field_store->getTimeIntegrationRules()) {
112 auto it = main_unknown_name_to_local_idx_.find(mapping.primary_name);
113 if (it == main_unknown_name_to_local_idx_.end()) {
116 size_t u_idx = system_->field_store->getFieldIndex(mapping.primary_name);
117 system_->field_store->setField(u_idx, primary_unknowns[it->second]);
// Post-solve systems: result i is stored in ps->field_store under the field getWeakFormReaction
// names for weak form i.
// NOTE(review): unlike the cycle-zero path, there is no check that ps_unknowns.size() matches
// ps->weak_forms.size(), and solve_result_field_names is not consulted here.
122 for (
const auto& ps : post_solve_systems_) {
123 auto ps_unknowns = ps->solve(time_info);
124 for (
size_t i = 0; i < ps->weak_forms.size(); ++i) {
125 const std::string reaction_name = ps->field_store->getWeakFormReaction(ps->weak_forms[i]->name());
126 size_t u_idx = ps->field_store->getFieldIndex(reaction_name);
127 ps->field_store->setField(u_idx, ps_unknowns[i]);
// Snapshot of every field in the main field store. Post-solve results appear here only if
// ps->field_store is the same store as system_->field_store (NOTE(review): confirm).
134 std::vector<FieldState> new_states = system_->field_store->getAllFields();
136 for (
const auto& [rule, mapping] : system_->field_store->getTimeIntegrationRules()) {
137 auto it = main_unknown_name_to_local_idx_.find(mapping.primary_name);
138 if (it == main_unknown_name_to_local_idx_.end()) {
141 size_t u_idx = system_->field_store->getFieldIndex(mapping.primary_name);
142 FieldState u_new = primary_unknowns[it->second];
143 new_states[u_idx] = u_new;
// Rule inputs are, in order: new primary, old primary, old first derivative, old second
// derivative, truncated to rule->num_args(). "Old" values come from current_states.
// NOTE(review): `rule` is dereferenced without the null check used in the any_of lambda above.
// NOTE(review): if dot_name is empty but num_args() >= 4, the old ddot value lands in slot 2
// instead of slot 3. Confirm this is intended.
145 std::vector<FieldState> rule_inputs;
146 rule_inputs.push_back(u_new);
147 if (rule->num_args() >= 2) {
148 rule_inputs.push_back(current_states[u_idx]);
151 if (rule->num_args() >= 3 && !mapping.dot_name.empty()) {
152 size_t v_idx = system_->field_store->getFieldIndex(mapping.dot_name);
153 rule_inputs.push_back(current_states[v_idx]);
156 if (rule->num_args() >= 4 && !mapping.ddot_name.empty()) {
157 size_t a_idx = system_->field_store->getFieldIndex(mapping.ddot_name);
158 rule_inputs.push_back(current_states[a_idx]);
// Update the derivative fields that this mapping declares, using the rule's corrector formulas.
161 if (!mapping.dot_name.empty()) {
162 size_t v_idx = system_->field_store->getFieldIndex(mapping.dot_name);
163 new_states[v_idx] = rule->corrected_dot(time_info, rule_inputs);
166 if (!mapping.ddot_name.empty()) {
167 size_t a_idx = system_->field_store->getFieldIndex(mapping.ddot_name);
168 new_states[a_idx] = rule->corrected_ddot(time_info, rule_inputs);
// The history field receives the newly solved primary value.
171 if (!mapping.history_name.empty()) {
172 size_t hist_idx = system_->field_store->getFieldIndex(mapping.history_name);
173 new_states[hist_idx] = u_new;
// Rules whose primary field is NOT a main unknown: copy the current primary value (from the
// snapshot above) into the history field. Main-unknown rules were handled in the previous loop;
// the body of this skip branch is not shown in this listing (presumably it skips the rule).
180 for (
const auto& [rule, mapping] : system_->field_store->getTimeIntegrationRules()) {
181 if (main_unknown_name_to_local_idx_.count(mapping.primary_name)) {
184 if (!mapping.history_name.empty()) {
185 size_t primary_idx = system_->field_store->getFieldIndex(mapping.primary_name);
186 size_t hist_idx = system_->field_store->getFieldIndex(mapping.history_name);
187 new_states[hist_idx] = new_states[primary_idx];
// Leave the field store holding the advanced state, then return it along with the reactions.
191 for (
size_t i = 0; i < new_states.size(); ++i) {
192 system_->field_store->setField(i, new_states[i]);
195 return {new_states, reactions};
MultiphysicsTimeIntegrator(std::shared_ptr< SystemBase > system, std::vector< std::shared_ptr< SystemBase >> cycle_zero_systems={}, std::vector< std::shared_ptr< SystemBase >> post_solve_systems={})
Construct a multiphysics advancer around main and auxiliary systems.
std::pair< std::vector< FieldState >, std::vector< ReactionState > > advanceState(const TimeInfo &time_info, const FieldState &shape_disp, const std::vector< FieldState > &states, const std::vector< FieldState > &params) const override
Advance the multiphysics state by one time step.
void addPostSolveSystem(std::shared_ptr< SystemBase > system)
Register a system to be solved after the main solve and reaction computation.
Contains DirichletBoundaryConditions class for interaction with the differentiable solve interfaces.
Accelerator functionality.
gretl::State< FEFieldPtr, FEDualPtr > FieldState
typedef
This file contains nonlinear block solver interfaces and helpers.
Methods for solving systems of equations as given by WeakForms. Tracks these operations on the gretl ...
Reaction class, which is a named combination of a weak form and a set of Dirichlet-constrained nodes.
Struct storing time and time-step information.
double dt() const
accessor for dt
size_t cycle() const
accessor for cycle
@ CycleZero
Initialization or cycle zero step.
double time() const
accessor for the current time