7 #include "smith/differentiable_numerics/system_solver.hpp"
16 #include <unordered_map>
18 #include <axom/slic.hpp>
19 #include <axom/fmt.hpp>
24 : max_staggered_iterations_(1), exact_staggered_steps_(false)
30 : max_staggered_iterations_(max_staggered_iterations), exact_staggered_steps_(exact_staggered_steps)
32 SLIC_ERROR_IF(max_staggered_iterations <= 0, "max_staggered_iterations must be > 0
");
// addSubsystemSolver: validate and append one stage to the staggered solve sequence.
//   block_indices     Residual rows / unknown blocks this stage solves together. Not range-checked
//                     here; solve() rejects indices >= the residual count and expands an empty
//                     list to every residual row.
//   solver            Nonlinear block solver that performs this stage's solve; must be non-null.
//                     Stored as a shared_ptr, so the same solver object may also be used by other
//                     stages or SystemSolvers that were handed the same pointer.
//   relaxation_factor Relaxation weight omega, must lie in (0, 1] (declared default 1.0). When
//                     omega != 1, solve() blends omega * x_solved + (1 - omega) * x_k.
// Invalid arguments are reported through SLIC_ERROR_IF.
35 void SystemSolver::addSubsystemSolver(const std::vector<size_t>& block_indices,
36 std::shared_ptr<NonlinearBlockSolverBase> solver, double relaxation_factor)
38 SLIC_ERROR_IF(!solver, "SystemSolver stage solver must be non-
null");
// omega == 0 is rejected as well: a stage with zero weight would never update its unknowns.
39 SLIC_ERROR_IF(relaxation_factor <= 0.0 || relaxation_factor > 1.0,
40 axom::fmt::format("Stage relaxation_factor {} must be in (0, 1]
", relaxation_factor));
// Stages are executed in insertion order during each staggered iteration of solve().
42 stages_.push_back(Stage{block_indices, std::move(solver), relaxation_factor});
// appendStagesWithBlockMapping: copy every stage of `subsystem_solver` onto the end of this
// solver, translating block indices from the subsystem's local numbering into this solver's
// numbering via global_block_indices[local_index].
//   subsystem_solver      Solver whose stages are copied. Stage solvers are shared (the
//                         shared_ptr is copied, not cloned) and relaxation factors are preserved.
//   global_block_indices  Map from subsystem-local block index to global block index; must be
//                         non-empty. Presumably one entry per subsystem residual -- TODO confirm.
// A subsystem stage with an empty block list (which solve() treats as "all residuals") is mapped
// to the full global_block_indices list. Out-of-range local indices are reported via SLIC_ERROR_IF.
// NOTE(review): if subsystem_solver is *this, addSubsystemSolver() push_backs into the very
// stages_ vector the range-for below is iterating, which can invalidate its iterators (undefined
// behavior). Consider rejecting self-append or iterating over a copy of the stage list.
45 void SystemSolver::appendStagesWithBlockMapping(const SystemSolver& subsystem_solver,
46 const std::vector<size_t>& global_block_indices)
48 SLIC_ERROR_IF(global_block_indices.empty(), "Global block index map must be non-empty
");
50 for (const auto& stage : subsystem_solver.stages_) {
51 std::vector<size_t> remapped_block_indices;
// Empty local list means "every subsystem block": map it to the whole global list.
52 if (stage.block_indices.empty()) {
53 remapped_block_indices = global_block_indices;
// Otherwise translate each local index individually, bounds-checking against the map size.
55 remapped_block_indices.reserve(stage.block_indices.size());
56 for (size_t local_block_index : stage.block_indices) {
57 SLIC_ERROR_IF(local_block_index >= global_block_indices.size(),
58 axom::fmt::format("Local block index {} exceeds subsystem
size {}
", local_block_index,
59 global_block_indices.size()));
60 remapped_block_indices.push_back(global_block_indices[local_block_index]);
// addSubsystemSolver() re-validates the (shared) solver pointer and the relaxation factor.
63 addSubsystemSolver(remapped_block_indices, stage.solver, stage.relaxation_factor);
// solve: run the staggered (block-partitioned) nonlinear solve over all residual rows.
//
//   residual_evals  One WeakForm per residual row; its size defines the number of rows.
//   block_indices   block_indices[r][c] is used as an index into states[r] for the field that is
//                   the unknown of block c. The diagonal entry block_indices[r][r] selects row r's
//                   own unknown, which is the value that gets relaxed and returned.
//   shape_disp      Shape displacement passed to every residual evaluation and block solve.
//   states          Per-row field inputs; copied into a working set that is updated in place.
//   params          Per-row parameter fields, appended after the states as residual inputs.
//   time_info       Passed through to residual evaluation and block_solve.
//   bc_managers     Per-row boundary condition managers. Null entries are allowed: for such rows
//                   the residual used for convergence checks is not BC-zeroed (how block_solve
//                   treats a null manager is not visible here).
//
// Returns, for each row r, current_states[r][block_indices[r][r]] after the staggered loop.
//
// Each staggered iteration runs the stages in order. Every stage's solution is immediately
// written into all rows that reference a field of the same name, so later stages in the same
// iteration see the updated values. Unless exact_staggered_steps_ is set, the residuals of all
// stages are re-checked after each sweep (except the last) so the loop can break early.
//
// NOTE(review): only stage block indices are range-checked. block_indices, states, params and
// bc_managers are indexed by row without verifying they hold num_residuals entries, and the
// diagonal slot block_indices[r][r] is assumed valid, so mismatched inputs read out of bounds.
// NOTE(review): although this method is const, setInnerToleranceMultiplier() and
// resetConvergenceState() mutate the stage solver objects. Those are shared_ptrs shared with
// stages_ (active_stages is a shallow copy) and possibly with other SystemSolvers.
67 std::vector<FieldState> SystemSolver::solve(const std::vector<WeakForm*>& residual_evals,
68 const std::vector<std::vector<size_t>>& block_indices,
69 const FieldState& shape_disp,
70 const std::vector<std::vector<FieldState>>& states,
71 const std::vector<std::vector<FieldState>>& params,
72 const TimeInfo& time_info,
73 const std::vector<const BoundaryConditionManager*>& bc_managers) const
75 SLIC_ERROR_IF(stages_.empty(), "SystemSolver has no stages defined.
");
77 size_t num_residuals = residual_evals.size();
// Local copy of the stage list so empty block lists can be expanded without touching stages_.
78 std::vector<Stage> active_stages = stages_;
79 for (auto& stage : active_stages) {
// An empty stage block list means "solve every residual row": expand to 0..num_residuals-1.
80 if (stage.block_indices.empty()) {
81 stage.block_indices.resize(num_residuals);
82 std::iota(stage.block_indices.begin(), stage.block_indices.end(), 0);
// Validate every (possibly expanded) stage index against the residual count.
84 for (size_t block_index : stage.block_indices) {
85 SLIC_ERROR_IF(block_index >= num_residuals,
86 axom::fmt::format("Stage block index {} exceeds residual count {}
", block_index, num_residuals));
89 // Set the inner tolerance factor based on the number of stages. For single-stage
90 // solves, we don't want to reduce the tolerances as that's pointless and
91 // unintuitive. For multi-stage solves, we want a tighter inner solve to
92 // ensure outer staggered convergence.
93 const double inner_tol_factor = (active_stages.size() == 1) ? 1.0 : 0.6;
94 for (auto& stage : active_stages) {
95 stage.solver->setInnerToleranceMultiplier(inner_tol_factor);
98 // Reset each stage solver's convergence tracking (e.g. initial residual norm for rel-tol)
99 for (const auto& stage : active_stages) {
100 stage.solver->resetConvergenceState();
// One convergence context per stage: primed from the initial residuals below and reused by
// the outer convergence check after each staggered sweep.
102 std::vector<NonlinearConvergenceContext> stage_convergence_contexts(active_stages.size());
104 // Working copy of states, updated in-place as stages solve
105 std::vector<std::vector<FieldState>> current_states = states;
107 // Pre-compute name -> (row, slot) routing so the propagation loop avoids O(N*M) string compares
108 // on every staggered iteration. Field-name identity within current_states is invariant across
109 // the iteration loop: only values are replaced, never the underlying name.
110 std::unordered_map<std::string, std::vector<std::pair<size_t, size_t>>> field_routing;
111 for (size_t r = 0; r < num_residuals; ++r) {
112 for (size_t slot = 0; slot < current_states[r].size(); ++slot) {
113 field_routing[current_states[r][slot].get()->name()].emplace_back(r, slot);
117 // Helper lambda to assemble input pointers, evaluate residual, and zero essential BCs
118 auto eval_residual_and_zero_bcs = [&](size_t global_row) {
// Inputs are ordered as this row's states followed by its params. current_states is captured
// by reference, so each call sees the most recently propagated field values.
119 std::vector<const FiniteElementState*> input_ptrs;
120 for (const auto& field_state : current_states[global_row]) {
121 input_ptrs.push_back(field_state.get().get());
123 for (const auto& param_state : params[global_row]) {
124 input_ptrs.push_back(param_state.get().get());
126 mfem::Vector res = residual_evals[global_row]->residual(time_info, shape_disp.get().get(), input_ptrs);
// Zero the essential true dofs, presumably so constrained dofs do not influence the
// convergence measures computed from this residual.
127 if (bc_managers[global_row]) {
128 res.SetSubVector(bc_managers[global_row]->allEssentialTrueDofs(), 0.0);
133 // Evaluate and register true initial residuals before block sweeps mutate the state.
134 for (size_t stage_idx = 0; stage_idx < active_stages.size(); ++stage_idx) {
135 const auto& stage = active_stages[stage_idx];
136 size_t num_stage_blocks = stage.block_indices.size();
137 std::vector<mfem::Vector> stage_init_residuals;
138 for (size_t i = 0; i < num_stage_blocks; ++i) {
139 stage_init_residuals.push_back(eval_residual_and_zero_bcs(stage.block_indices[i]));
141 stage.solver->primeConvergenceContext(stage_init_residuals, stage_convergence_contexts[stage_idx]);
// Outer staggered loop: at most max_staggered_iterations_ sweeps over all stages.
144 for (int iter = 0; iter < max_staggered_iterations_; ++iter) {
145 // --- Run each stage ---
146 for (size_t stage_idx = 0; stage_idx < active_stages.size(); ++stage_idx) {
147 const auto& stage = active_stages[stage_idx];
148 size_t num_stage_blocks = stage.block_indices.size();
// Gather this stage's sub-system: only the rows listed in stage.block_indices, using the
// current (already propagated) states.
150 std::vector<WeakForm*> stage_residuals;
151 std::vector<std::vector<size_t>> stage_block_indices;
152 std::vector<std::vector<FieldState>> stage_states;
153 std::vector<std::vector<FieldState>> stage_params;
154 std::vector<const BoundaryConditionManager*> stage_bc_managers;
156 for (size_t i = 0; i < num_stage_blocks; ++i) {
157 size_t global_row = stage.block_indices[i];
158 stage_residuals.push_back(residual_evals[global_row]);
159 stage_bc_managers.push_back(bc_managers[global_row]);
160 stage_states.push_back(current_states[global_row]);
161 stage_params.push_back(params[global_row]);
// Restrict block_indices[global_row] to the stage's columns: entry col_idx is the slot in
// this row's state list of the unknown for the stage's col_idx-th block.
163 std::vector<size_t> row_indices(num_stage_blocks, invalid_block_index);
164 for (size_t col_idx = 0; col_idx < num_stage_blocks; ++col_idx) {
165 size_t global_col = stage.block_indices[col_idx];
166 row_indices[col_idx] = block_indices[global_row][global_col];
168 stage_block_indices.push_back(row_indices);
171 std::vector<FieldState> stage_solutions =
172 block_solve(stage_residuals, stage_block_indices, shape_disp, stage_states, stage_params, time_info,
173 stage.solver.get(), stage_bc_managers);
175 // Propagate updated fields to every residual input that references the solved field.
176 // Match by field name (looked up via the pre-computed routing map): coupling fields appear
177 // as fixed inputs in other rows and therefore do not have a valid unknown-block entry there.
178 // Apply relaxation: x_new = omega * x_solved + (1 - omega) * x_k.
179 for (size_t i = 0; i < num_stage_blocks; ++i) {
180 size_t global_col = stage.block_indices[i];
181 FieldState new_state = stage_solutions[i];
// x_k is the previous value of this block's own unknown, read from its diagonal slot.
183 if (stage.relaxation_factor != 1.0) {
184 FieldState old_state = current_states[global_col][block_indices[global_col][global_col]];
185 new_state = weighted_average(new_state, old_state, stage.relaxation_factor);
188 auto it = field_routing.find(new_state.get()->name());
189 if (it != field_routing.end()) {
190 for (const auto& [r, slot] : it->second) {
191 current_states[r][slot] = new_state;
197 // --- Convergence check (skipped in exact-steps mode, single-iteration mode,
198 // or on the last iteration where a break has no effect) ---
// Note: "max_staggered_iterations_ > 1" is already implied by "iter < max_staggered_iterations_ - 1".
199 if (!exact_staggered_steps_ && max_staggered_iterations_ > 1 && iter < max_staggered_iterations_ - 1) {
200 bool all_converged = true;
201 for (size_t s = 0; s < active_stages.size(); ++s) {
202 const auto& stage = active_stages[s];
203 size_t num_stage_blocks = stage.block_indices.size();
// Re-evaluate every stage residual with the fully propagated states of this sweep.
204 std::vector<mfem::Vector> stage_residuals;
205 for (size_t i = 0; i < num_stage_blocks; ++i) {
206 stage_residuals.push_back(eval_residual_and_zero_bcs(stage.block_indices[i]));
// NOTE(review): the leading 1.0 is presumably a tolerance multiplier (outer check uses the
// unscaled tolerance, unlike the tightened inner solves) -- confirm against convergenceStatus().
208 auto stage_status = stage.solver->convergenceStatus(1.0, stage_residuals, stage_convergence_contexts[s]);
210 if (!stage_status.converged) {
211 all_converged = false;
// Every stage reported convergence: log it; per the comment above, the staggered loop then breaks.
216 SLIC_INFO_ROOT(axom::fmt::format("Staggered iteration converged after {} iteration(s)
", iter + 1));
222 // Return the diagonal (unknown) states as the final solution
223 std::vector<FieldState> final_solutions;
224 final_solutions.reserve(num_residuals);
225 for (size_t r = 0; r < num_residuals; ++r) {
226 size_t s_idx = block_indices[r][r];
227 final_solutions.push_back(current_states[r][s_idx]);
230 return final_solutions;
233 std::shared_ptr<SystemSolver> SystemSolver::singleBlockSolver(size_t block_index) const
235 constexpr bool exact_staggered_steps = true;
236 for (const auto& stage : stages_) {
237 if (stage.block_indices.empty()) {
238 auto result = std::make_shared<SystemSolver>(1, exact_staggered_steps);
239 std::shared_ptr<NonlinearBlockSolverBase> stage_solver = stage.solver;
240 if (const auto* equation_solver = dynamic_cast<const NonlinearBlockSolver*>(stage.solver.get())) {
241 if (auto cloned_solver = equation_solver->cloneFresh()) {
242 stage_solver = cloned_solver;
245 Stage single_stage{{0}, stage_solver, stage.relaxation_factor};
246 result->addSubsystemSolver(single_stage.block_indices, single_stage.solver, single_stage.relaxation_factor);
250 auto found = std::find(stage.block_indices.begin(), stage.block_indices.end(), block_index);
251 if (found != stage.block_indices.end()) {
252 auto result = std::make_shared<SystemSolver>(1, exact_staggered_steps);
253 std::shared_ptr<NonlinearBlockSolverBase> stage_solver = stage.solver;
254 if (const auto* equation_solver = dynamic_cast<const NonlinearBlockSolver*>(stage.solver.get())) {
255 if (auto cloned_solver = equation_solver->cloneFresh()) {
256 stage_solver = cloned_solver;
259 Stage single_stage{{0}, stage_solver, stage.relaxation_factor};
260 result->addSubsystemSolver(single_stage.block_indices, single_stage.solver, single_stage.relaxation_factor);
This file contains the declaration of the boundary condition manager class.
Orchestrates staggered solution for multiphysics systems.
SystemSolver(std::shared_ptr< NonlinearBlockSolverBase > single_solver)
Construct a monolithic SystemSolver from a single block solver.
void addSubsystemSolver(const std::vector< size_t > &block_indices, std::shared_ptr< NonlinearBlockSolverBase > solver, double relaxation_factor=1.0)
Convenience method to add a solver stage.
Accelerator functionality.
constexpr SMITH_HOST_DEVICE int size(const tensor< T, n... > &)
returns the total number of stored values in a tensor
This file contains nonlinear block solver interfaces and helpers.
Methods for solving systems of equations as given by WeakForms. Tracks these operations on the gretl ...
Represents a single stage in a staggered iteration.