7 #include "smith/differentiable_numerics/field_store.hpp"
10 #include "gretl/wang_checkpoint_strategy.hpp"
16 graph_(std::make_shared<gretl::DataStore>(std::make_unique<gretl::WangCheckpointStrategy>(storage_size))),
17 prepend_name_(std::move(prepend_name))
23 if (prepend_name_.empty()) {
26 return prepend_name_ +
"_" + base;
29 std::shared_ptr<DirichletBoundaryConditions> FieldStore::addBoundaryConditions(
FEFieldPtr field)
31 return std::make_shared<DirichletBoundaryConditions>(mesh_->mfemParMesh(), field->space());
36 FieldLabel argument_name_and_index{.field_name = argument_name, .field_index_in_residual = argument_index};
37 if (weak_form_name_to_unknown_name_index_.count(weak_form_name)) {
38 weak_form_name_to_unknown_name_index_.at(weak_form_name).push_back(argument_name_and_index);
40 weak_form_name_to_unknown_name_index_[weak_form_name] = std::vector<FieldLabel>{argument_name_and_index};
47 if (weak_form_name_to_field_names_.count(weak_form_name)) {
48 weak_form_name_to_field_names_.at(weak_form_name).push_back(argument_name);
50 weak_form_name_to_field_names_[weak_form_name] = std::vector<std::string>{argument_name};
52 SLIC_ERROR_IF(argument_index + 1 != weak_form_name_to_field_names_.at(weak_form_name).size(),
53 "Invalid order for adding weak form arguments.");
58 for (
auto& keyval : weak_form_name_to_unknown_name_index_) {
59 std::cout <<
"for residual: " << keyval.first <<
" ";
60 for (
auto& name_index : keyval.second) {
61 std::cout <<
"arg " << name_index.field_name <<
" at " << name_index.field_index_in_residual <<
", ";
63 std::cout << std::endl;
67 std::vector<std::vector<size_t>>
FieldStore::indexMap(
const std::vector<std::string>& residual_names)
const
74 std::map<size_t, size_t> global_state_to_local_col;
75 for (
size_t res_i = 0; res_i < residual_names.size(); ++res_i) {
76 const std::string& res_name = residual_names[res_i];
77 size_t global_state_idx = invalid_block_index;
79 std::string reaction_name;
80 for (
const auto& kv : weak_form_to_test_field_) {
81 if (kv.first == res_name) {
82 reaction_name = kv.second;
88 bool reaction_is_unknown =
false;
89 if (!reaction_name.empty() && weak_form_name_to_unknown_name_index_.count(res_name)) {
90 for (
const auto& label : weak_form_name_to_unknown_name_index_.at(res_name)) {
91 if (label.field_name == reaction_name) {
92 reaction_is_unknown =
true;
98 if (reaction_is_unknown) {
99 global_state_idx = to_states_index_.at(reaction_name);
101 const auto& arg_info = weak_form_name_to_unknown_name_index_.at(res_name);
102 SLIC_ERROR_IF(arg_info.empty(),
103 "Weak form '" << res_name <<
"' has no unknown arguments; cannot build index map.");
104 global_state_idx = to_states_index_.at(arg_info.front().field_name);
106 global_state_to_local_col[global_state_idx] = res_i;
109 std::vector<std::vector<size_t>> block_indices(residual_names.size());
110 for (
size_t res_i = 0; res_i < residual_names.size(); ++res_i) {
111 std::vector<size_t>& res_indices = block_indices[res_i];
112 res_indices = std::vector<size_t>(residual_names.size(), invalid_block_index);
113 const std::string& res_name = residual_names[res_i];
114 const auto& arg_info = weak_form_name_to_unknown_name_index_.at(res_name);
116 for (
const auto& field_name_and_arg_index : arg_info) {
117 size_t global_state_index = to_states_index_.at(field_name_and_arg_index.field_name);
118 auto it = global_state_to_local_col.find(global_state_index);
119 if (it != global_state_to_local_col.end()) {
120 res_indices[it->second] = field_name_and_arg_index.field_index_in_residual;
126 return block_indices;
130 const std::vector<std::string>& weak_form_names)
const
132 std::vector<std::string> field_names;
133 field_names.reserve(weak_form_names.size());
134 for (
const auto& wf_name : weak_form_names) {
141 const std::vector<std::string>& field_names)
const
143 struct BoundaryConditionRef {
144 std::string primary_name;
145 bool use_second_derivative;
147 std::map<std::string, BoundaryConditionRef> field_to_primary;
148 for (
const auto& [_rule, mapping] : time_integration_rules_) {
149 if (!mapping.primary_name.empty()) {
150 field_to_primary[mapping.primary_name] = {mapping.primary_name,
false};
152 if (!mapping.history_name.empty()) {
153 field_to_primary[mapping.history_name] = {mapping.primary_name,
false};
155 if (!mapping.ddot_name.empty()) {
156 field_to_primary[mapping.ddot_name] = {mapping.primary_name,
true};
160 std::vector<const BoundaryConditionManager*> bcs;
161 for (
const auto& field_name : field_names) {
163 auto direct = boundary_conditions_.find(field_name);
164 if (direct != boundary_conditions_.end()) {
165 bcs.push_back(&direct->second->getBoundaryConditionManager());
170 auto ref_it = field_to_primary.find(field_name);
171 if (ref_it == field_to_primary.end()) {
172 bcs.push_back(
nullptr);
175 auto primary_it = boundary_conditions_.find(ref_it->second.primary_name);
176 if (primary_it == boundary_conditions_.end()) {
177 bcs.push_back(
nullptr);
180 const auto& dbc = *primary_it->second;
181 bcs.push_back(ref_it->second.use_second_derivative ? &dbc.getSecondDerivativeManager()
182 : &dbc.getBoundaryConditionManager());
189 const auto resolved_name = resolveFieldName(field_name);
190 if (to_states_index_.count(resolved_name))
return true;
191 if (to_params_index_.count(resolved_name))
return true;
192 if (!shape_disp_.empty() && shape_disp_[0].get()->name() == resolved_name)
return true;
198 const auto resolved_name = resolveFieldName(field_name);
199 if (to_states_index_.count(resolved_name)) {
200 return to_states_index_.at(resolved_name);
202 if (to_params_index_.count(resolved_name)) {
203 return to_params_index_.at(resolved_name);
205 SLIC_ERROR(
"Field or parameter '" << field_name <<
"' not found in getFieldIndex");
211 const auto resolved_name = resolveFieldName(field_name);
213 if (to_states_index_.count(resolved_name)) {
214 size_t field_index = to_states_index_.at(resolved_name);
215 return states_[field_index];
218 if (to_params_index_.count(resolved_name)) {
219 size_t param_index = to_params_index_.at(resolved_name);
220 return params_[param_index];
222 SLIC_ERROR(
"Field or parameter '" << field_name <<
"' not found");
228 const auto resolved_name = resolveFieldName(param_name);
229 size_t param_index = to_params_index_.at(resolved_name);
230 return params_[param_index];
235 const auto resolved_name = resolveFieldName(field_name);
236 if (to_states_index_.count(resolved_name)) {
237 states_[to_states_index_.at(resolved_name)] = updated_field;
240 if (to_params_index_.count(resolved_name)) {
241 params_[to_params_index_.at(resolved_name)] = updated_field;
244 if (!shape_disp_.empty() && shape_disp_[0].get()->name() == resolved_name) {
245 shape_disp_[0] = updated_field;
248 SLIC_ERROR(
"Field '" << field_name <<
"' not found in setField");
251 std::string FieldStore::resolveFieldName(
const std::string& field_name)
const
253 if (to_states_index_.count(field_name) || to_params_index_.count(field_name)) {
256 if (!shape_disp_.empty() && shape_disp_[0].get()->name() == field_name) {
260 const auto prefixed_name =
prefix(field_name);
261 if (prefixed_name != field_name) {
262 if (to_states_index_.count(prefixed_name) || to_params_index_.count(prefixed_name)) {
263 return prefixed_name;
265 if (!shape_disp_.empty() && shape_disp_[0].get()->name() == prefixed_name) {
266 return prefixed_name;
280 SLIC_ERROR_ROOT_IF(weak_form_name_to_field_names_.count(weak_form_name) == 0,
281 axom::fmt::format(
"Weak form '{}' not found in FieldStore. Did you forget to call addReaction()?",
284 auto field_names = weak_form_name_to_field_names_.at(weak_form_name);
285 std::vector<FieldState> fields_for_residual;
286 for (
auto& name : field_names) {
289 to_states_index_.count(name) == 0 && to_params_index_.count(name) == 0,
290 axom::fmt::format(
"Field '{}' (required by weak form '{}') not found in FieldStore", name, weak_form_name));
294 if (to_states_index_.count(name)) {
295 fields_for_residual.push_back(
getField(name));
298 return fields_for_residual;
302 const std::vector<FieldState>& state_fields,
303 const std::vector<FieldState>& param_fields)
const
306 SLIC_ERROR_ROOT_IF(weak_form_name_to_field_names_.count(weak_form_name) == 0,
307 axom::fmt::format(
"Weak form '{}' not found in FieldStore. Did you forget to call addReaction()?",
310 auto field_names = weak_form_name_to_field_names_.at(weak_form_name);
311 std::vector<FieldState> fields_for_residual;
312 for (
auto& name : field_names) {
314 if (to_states_index_.count(name)) {
315 size_t idx = to_states_index_.at(name);
316 SLIC_ERROR_ROOT_IF(idx >= state_fields.size(),
317 axom::fmt::format(
"State field index {} out of bounds (size={}) for field '{}'", idx,
318 state_fields.size(), name));
319 fields_for_residual.push_back(state_fields[idx]);
322 else if (to_params_index_.count(name)) {
323 size_t idx = to_params_index_.at(name);
324 SLIC_ERROR_ROOT_IF(idx >= param_fields.size(),
325 axom::fmt::format(
"Parameter field index {} out of bounds (size={}) for field '{}'", idx,
326 param_fields.size(), name));
327 fields_for_residual.push_back(param_fields[idx]);
329 SLIC_ERROR_ROOT(axom::fmt::format(
"Field or parameter '{}' (required by weak form '{}') not found in FieldStore",
330 name, weak_form_name));
333 return fields_for_residual;
340 auto it = boundary_conditions_.find(field_name);
341 if (it != boundary_conditions_.end()) {
352 return time_integration_rules_;
359 internal_weak_forms_.insert(weak_form_name);
364 for (
auto& kv : weak_form_to_test_field_) {
365 if (kv.first == weak_form_name) {
366 kv.second = field_name;
370 weak_form_to_test_field_.push_back({weak_form_name, field_name});
375 for (
const auto& kv : weak_form_to_test_field_) {
376 if (kv.first == weak_form_name) {
380 SLIC_ERROR(
"Reaction field not found for weak form " << weak_form_name);
390 std::vector<FieldState> output;
391 std::set<std::string> public_static_fields;
392 for (
const auto& [rule, mapping] : time_integration_rules_) {
393 if (mapping.history_name.empty() && mapping.dot_name.empty() && mapping.ddot_name.empty()) {
394 public_static_fields.insert(mapping.primary_name);
397 for (
size_t i = 0; i < states_.size(); ++i) {
398 if (!is_solve_state_[i] || public_static_fields.count(states_[i].get()->name()) > 0) {
399 output.push_back(states_[i]);
407 std::vector<ReactionInfo> infos;
408 for (
const auto& kv : weak_form_to_test_field_) {
409 const std::string& weak_form_name = kv.first;
410 if (internal_weak_forms_.count(weak_form_name)) {
413 const std::string& field_name = kv.second;
414 infos.push_back({weak_form_name, &
getField(field_name).get()->space()});
Contains DirichletBoundaryConditions class for interaction with the differentiable solve interfaces.
Accelerator functionality.
std::shared_ptr< FiniteElementState > FEFieldPtr
typedef
gretl::State< FEFieldPtr, FEDualPtr > FieldState
typedef
Methods for solving systems of equations as given by WeakForms. Tracks these operations on the gretl ...
Mapping between primary and history/derivative fields for time integration.
void printMap()
Print the internal field maps for debugging.
std::vector< FieldState > getOutputFieldStates() const
Get the list of state fields suitable for output: every non-solve state, plus the primary field of any time integration rule that has no history or derivative fields.
const std::vector< FieldState > & getParameterFields() const
Get the list of all parameter fields.
const std::vector< std::pair< std::shared_ptr< TimeIntegrationRule >, TimeIntegrationMapping > > & getTimeIntegrationRules() const
Get all registered time integration rules and their mappings.
bool hasField(const std::string &field_name) const
Check whether a field exists.
const std::vector< FieldState > & getAllFields() const
Get all fields stored in the FieldStore.
std::vector< std::vector< size_t > > indexMap(const std::vector< std::string > &residual_names) const
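Generate the block index map for the given residuals: entry [i][j] is the argument index, within residual i, of the unknown associated with residual j, or invalid_block_index if residual i does not take that unknown.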
std::string prefix(const std::string &base) const
Apply this store's namespace prefix to a base name.
std::string getWeakFormReaction(const std::string &weak_form_name) const
Get the name of the reaction (test) field for a weak form.
const std::shared_ptr< smith::Mesh > & getMesh() const
Get associated mesh shared by all registered fields.
const std::vector< FieldState > & getStateFields() const
Get the list of all state fields.
void setField(const std::string &field_name, FieldState updated_field)
Update a field in the store by name.
void addWeakFormArg(std::string weak_form_name, std::string argument_name, size_t argument_index)
Register an argument to a weak form.
void addWeakFormReaction(std::string weak_form_name, std::string field_name)
Register the reaction (test) field for a weak form.
FieldState getParameter(const std::string &param_name) const
Get a parameter field by name.
size_t getFieldIndex(const std::string &field_name) const
Get the internal index of a field by name.
std::vector< FieldState > getStates(const std::string &weak_form_name) const
Get the state (non-parameter) fields that are arguments of a weak form, in argument order.
FieldState getShapeDisp() const
Get the shape displacement field.
std::vector< const BoundaryConditionManager * > getBoundaryConditionManagers(const std::vector< std::string > &weak_form_names) const
Get the boundary condition managers for the given weak forms, one per residual row.
void addWeakFormUnknownArg(std::string weak_form_name, std::string argument_name, size_t argument_index)
Register an argument to a weak form as an unknown.
std::vector< FieldState > getStatesFromVectors(const std::string &weak_form_name, const std::vector< FieldState > &state_fields, const std::vector< FieldState > &param_fields) const
Extract all arguments (state and parameter fields) of a weak form, in argument order, from the provided state and parameter vectors.
std::shared_ptr< DirichletBoundaryConditions > getBoundaryConditions(const std::string &field_name) const
Get the boundary conditions for a given field name.
std::vector< const BoundaryConditionManager * > getBoundaryConditionManagersForFields(const std::vector< std::string > &field_names) const
Get ordered boundary condition managers corresponding to an ordered list of fields.
FieldState getField(const std::string &field_name) const
Get a FieldState by name.
const std::shared_ptr< gretl::DataStore > & graph() const
Get the associated data store graph.
std::vector< ReactionInfo > getReactionInfos() const
Get the weak form name and reaction (test) field space for every weak form that is not marked internal.
FieldStore(std::shared_ptr< Mesh > mesh, size_t storage_size=50, std::string prepend_name="")
Construct a new FieldStore object.
void markWeakFormInternal(const std::string &weak_form_name)
Mark a weak form as internal so it is excluded from getReactionInfos().