Smith  0.1
Smith is an implicit thermal structural mechanics simulation code.
field_store.cpp
1 // Copyright (c) Lawrence Livermore National Security, LLC and
2 // other Smith Project Developers. See the top-level LICENSE file for
3 // details.
4 //
5 // SPDX-License-Identifier: (BSD-3-Clause)
6 
7 #include "smith/differentiable_numerics/field_store.hpp"
10 #include "gretl/wang_checkpoint_strategy.hpp"
11 
12 namespace smith {
13 
14 FieldStore::FieldStore(std::shared_ptr<Mesh> mesh, size_t storage_size, std::string prepend_name)
15  : mesh_(mesh),
16  graph_(std::make_shared<gretl::DataStore>(std::make_unique<gretl::WangCheckpointStrategy>(storage_size))),
17  prepend_name_(std::move(prepend_name))
18 {
19 }
20 
21 std::string FieldStore::prefix(const std::string& base) const
22 {
23  if (prepend_name_.empty()) {
24  return base;
25  }
26  return prepend_name_ + "_" + base;
27 }
28 
29 std::shared_ptr<DirichletBoundaryConditions> FieldStore::addBoundaryConditions(FEFieldPtr field)
30 {
31  return std::make_shared<DirichletBoundaryConditions>(mesh_->mfemParMesh(), field->space());
32 }
33 
34 void FieldStore::addWeakFormUnknownArg(std::string weak_form_name, std::string argument_name, size_t argument_index)
35 {
36  FieldLabel argument_name_and_index{.field_name = argument_name, .field_index_in_residual = argument_index};
37  if (weak_form_name_to_unknown_name_index_.count(weak_form_name)) {
38  weak_form_name_to_unknown_name_index_.at(weak_form_name).push_back(argument_name_and_index);
39  } else {
40  weak_form_name_to_unknown_name_index_[weak_form_name] = std::vector<FieldLabel>{argument_name_and_index};
41  }
42 }
43 
44 void FieldStore::addWeakFormArg(std::string weak_form_name, std::string argument_name, size_t argument_index)
45 {
46  // Store the field name instead of index to avoid confusion between states_ and params_ indices
47  if (weak_form_name_to_field_names_.count(weak_form_name)) {
48  weak_form_name_to_field_names_.at(weak_form_name).push_back(argument_name);
49  } else {
50  weak_form_name_to_field_names_[weak_form_name] = std::vector<std::string>{argument_name};
51  }
52  SLIC_ERROR_IF(argument_index + 1 != weak_form_name_to_field_names_.at(weak_form_name).size(),
53  "Invalid order for adding weak form arguments.");
54 }
55 
57 {
58  for (auto& keyval : weak_form_name_to_unknown_name_index_) {
59  std::cout << "for residual: " << keyval.first << " ";
60  for (auto& name_index : keyval.second) {
61  std::cout << "arg " << name_index.field_name << " at " << name_index.field_index_in_residual << ", ";
62  }
63  std::cout << std::endl;
64  }
65 }
66 
67 std::vector<std::vector<size_t>> FieldStore::indexMap(const std::vector<std::string>& residual_names) const
68 {
69  // Build a local column space: each residual in the subsystem contributes one local column,
70  // corresponding to its "self" diagonal unknown. The self-unknown is preferably the residual's
71  // reaction (test) field if that field appears in the unknown-arg list for this weak form;
72  // otherwise fall back on the first unknown argument (handles cases like the cycle-zero
73  // acceleration solve, where the reaction field is a dependent/history field).
74  std::map<size_t, size_t> global_state_to_local_col;
75  for (size_t res_i = 0; res_i < residual_names.size(); ++res_i) {
76  const std::string& res_name = residual_names[res_i];
77  size_t global_state_idx = invalid_block_index;
78 
79  std::string reaction_name;
80  for (const auto& kv : weak_form_to_test_field_) {
81  if (kv.first == res_name) {
82  reaction_name = kv.second;
83  break;
84  }
85  }
86 
87  // Check if the reaction field is one of the registered unknown args for this weak form.
88  bool reaction_is_unknown = false;
89  if (!reaction_name.empty() && weak_form_name_to_unknown_name_index_.count(res_name)) {
90  for (const auto& label : weak_form_name_to_unknown_name_index_.at(res_name)) {
91  if (label.field_name == reaction_name) {
92  reaction_is_unknown = true;
93  break;
94  }
95  }
96  }
97 
98  if (reaction_is_unknown) {
99  global_state_idx = to_states_index_.at(reaction_name);
100  } else {
101  const auto& arg_info = weak_form_name_to_unknown_name_index_.at(res_name);
102  SLIC_ERROR_IF(arg_info.empty(),
103  "Weak form '" << res_name << "' has no unknown arguments; cannot build index map.");
104  global_state_idx = to_states_index_.at(arg_info.front().field_name);
105  }
106  global_state_to_local_col[global_state_idx] = res_i;
107  }
108 
109  std::vector<std::vector<size_t>> block_indices(residual_names.size());
110  for (size_t res_i = 0; res_i < residual_names.size(); ++res_i) {
111  std::vector<size_t>& res_indices = block_indices[res_i];
112  res_indices = std::vector<size_t>(residual_names.size(), invalid_block_index);
113  const std::string& res_name = residual_names[res_i];
114  const auto& arg_info = weak_form_name_to_unknown_name_index_.at(res_name);
115 
116  for (const auto& field_name_and_arg_index : arg_info) {
117  size_t global_state_index = to_states_index_.at(field_name_and_arg_index.field_name);
118  auto it = global_state_to_local_col.find(global_state_index);
119  if (it != global_state_to_local_col.end()) {
120  res_indices[it->second] = field_name_and_arg_index.field_index_in_residual;
121  }
122  // else: field belongs to a different subsystem; treat as fixed input here.
123  }
124  }
125 
126  return block_indices;
127 }
128 
129 std::vector<const BoundaryConditionManager*> FieldStore::getBoundaryConditionManagers(
130  const std::vector<std::string>& weak_form_names) const
131 {
132  std::vector<std::string> field_names;
133  field_names.reserve(weak_form_names.size());
134  for (const auto& wf_name : weak_form_names) {
135  field_names.push_back(getWeakFormReaction(wf_name));
136  }
137  return getBoundaryConditionManagersForFields(field_names);
138 }
139 
140 std::vector<const BoundaryConditionManager*> FieldStore::getBoundaryConditionManagersForFields(
141  const std::vector<std::string>& field_names) const
142 {
143  struct BoundaryConditionRef {
144  std::string primary_name;
145  bool use_second_derivative;
146  };
147  std::map<std::string, BoundaryConditionRef> field_to_primary;
148  for (const auto& [_rule, mapping] : time_integration_rules_) {
149  if (!mapping.primary_name.empty()) {
150  field_to_primary[mapping.primary_name] = {mapping.primary_name, false};
151  }
152  if (!mapping.history_name.empty()) {
153  field_to_primary[mapping.history_name] = {mapping.primary_name, false};
154  }
155  if (!mapping.ddot_name.empty()) {
156  field_to_primary[mapping.ddot_name] = {mapping.primary_name, true};
157  }
158  }
159 
160  std::vector<const BoundaryConditionManager*> bcs;
161  for (const auto& field_name : field_names) {
162  // Direct DBC entry takes precedence (e.g. an independent unknown like stress with its own BC).
163  auto direct = boundary_conditions_.find(field_name);
164  if (direct != boundary_conditions_.end()) {
165  bcs.push_back(&direct->second->getBoundaryConditionManager());
166  continue;
167  }
168 
169  // Otherwise resolve via the time-integration mapping that owns this reaction field.
170  auto ref_it = field_to_primary.find(field_name);
171  if (ref_it == field_to_primary.end()) {
172  bcs.push_back(nullptr);
173  continue;
174  }
175  auto primary_it = boundary_conditions_.find(ref_it->second.primary_name);
176  if (primary_it == boundary_conditions_.end()) {
177  bcs.push_back(nullptr);
178  continue;
179  }
180  const auto& dbc = *primary_it->second;
181  bcs.push_back(ref_it->second.use_second_derivative ? &dbc.getSecondDerivativeManager()
182  : &dbc.getBoundaryConditionManager());
183  }
184  return bcs;
185 }
186 
187 bool FieldStore::hasField(const std::string& field_name) const
188 {
189  const auto resolved_name = resolveFieldName(field_name);
190  if (to_states_index_.count(resolved_name)) return true;
191  if (to_params_index_.count(resolved_name)) return true;
192  if (!shape_disp_.empty() && shape_disp_[0].get()->name() == resolved_name) return true;
193  return false;
194 }
195 
196 size_t FieldStore::getFieldIndex(const std::string& field_name) const
197 {
198  const auto resolved_name = resolveFieldName(field_name);
199  if (to_states_index_.count(resolved_name)) {
200  return to_states_index_.at(resolved_name);
201  }
202  if (to_params_index_.count(resolved_name)) {
203  return to_params_index_.at(resolved_name);
204  }
205  SLIC_ERROR("Field or parameter '" << field_name << "' not found in getFieldIndex");
206  return 0; // unreachable
207 }
208 
209 FieldState FieldStore::getField(const std::string& field_name) const
210 {
211  const auto resolved_name = resolveFieldName(field_name);
212  // Check if it's a state field
213  if (to_states_index_.count(resolved_name)) {
214  size_t field_index = to_states_index_.at(resolved_name);
215  return states_[field_index];
216  }
217  // Otherwise check if it's a parameter
218  if (to_params_index_.count(resolved_name)) {
219  size_t param_index = to_params_index_.at(resolved_name);
220  return params_[param_index];
221  }
222  SLIC_ERROR("Field or parameter '" << field_name << "' not found");
223  return states_[0]; // unreachable, but needed for compilation
224 }
225 
226 FieldState FieldStore::getParameter(const std::string& param_name) const
227 {
228  const auto resolved_name = resolveFieldName(param_name);
229  size_t param_index = to_params_index_.at(resolved_name);
230  return params_[param_index];
231 }
232 
233 void FieldStore::setField(const std::string& field_name, FieldState updated_field)
234 {
235  const auto resolved_name = resolveFieldName(field_name);
236  if (to_states_index_.count(resolved_name)) {
237  states_[to_states_index_.at(resolved_name)] = updated_field;
238  return;
239  }
240  if (to_params_index_.count(resolved_name)) {
241  params_[to_params_index_.at(resolved_name)] = updated_field;
242  return;
243  }
244  if (!shape_disp_.empty() && shape_disp_[0].get()->name() == resolved_name) {
245  shape_disp_[0] = updated_field;
246  return;
247  }
248  SLIC_ERROR("Field '" << field_name << "' not found in setField");
249 }
250 
251 std::string FieldStore::resolveFieldName(const std::string& field_name) const
252 {
253  if (to_states_index_.count(field_name) || to_params_index_.count(field_name)) {
254  return field_name;
255  }
256  if (!shape_disp_.empty() && shape_disp_[0].get()->name() == field_name) {
257  return field_name;
258  }
259 
260  const auto prefixed_name = prefix(field_name);
261  if (prefixed_name != field_name) {
262  if (to_states_index_.count(prefixed_name) || to_params_index_.count(prefixed_name)) {
263  return prefixed_name;
264  }
265  if (!shape_disp_.empty() && shape_disp_[0].get()->name() == prefixed_name) {
266  return prefixed_name;
267  }
268  }
269 
270  return field_name;
271 }
272 
273 FieldState FieldStore::getShapeDisp() const { return shape_disp_[0]; }
274 
275 const std::vector<FieldState>& FieldStore::getAllFields() const { return states_; }
276 
277 std::vector<FieldState> FieldStore::getStates(const std::string& weak_form_name) const
278 {
279  // Validate that weak form is registered
280  SLIC_ERROR_ROOT_IF(weak_form_name_to_field_names_.count(weak_form_name) == 0,
281  axom::fmt::format("Weak form '{}' not found in FieldStore. Did you forget to call addReaction()?",
282  weak_form_name));
283 
284  auto field_names = weak_form_name_to_field_names_.at(weak_form_name);
285  std::vector<FieldState> fields_for_residual;
286  for (auto& name : field_names) {
287  // Validate that field exists
288  SLIC_ERROR_ROOT_IF(
289  to_states_index_.count(name) == 0 && to_params_index_.count(name) == 0,
290  axom::fmt::format("Field '{}' (required by weak form '{}') not found in FieldStore", name, weak_form_name));
291 
292  // Only include state fields, not parameters
293  // Parameters are passed separately to avoid duplication in block_solve
294  if (to_states_index_.count(name)) {
295  fields_for_residual.push_back(getField(name));
296  }
297  }
298  return fields_for_residual;
299 }
300 
301 std::vector<FieldState> FieldStore::getStatesFromVectors(const std::string& weak_form_name,
302  const std::vector<FieldState>& state_fields,
303  const std::vector<FieldState>& param_fields) const
304 {
305  // Validate that weak form is registered
306  SLIC_ERROR_ROOT_IF(weak_form_name_to_field_names_.count(weak_form_name) == 0,
307  axom::fmt::format("Weak form '{}' not found in FieldStore. Did you forget to call addReaction()?",
308  weak_form_name));
309 
310  auto field_names = weak_form_name_to_field_names_.at(weak_form_name);
311  std::vector<FieldState> fields_for_residual;
312  for (auto& name : field_names) {
313  // Check if it's a state field
314  if (to_states_index_.count(name)) {
315  size_t idx = to_states_index_.at(name);
316  SLIC_ERROR_ROOT_IF(idx >= state_fields.size(),
317  axom::fmt::format("State field index {} out of bounds (size={}) for field '{}'", idx,
318  state_fields.size(), name));
319  fields_for_residual.push_back(state_fields[idx]);
320  }
321  // Otherwise check if it's a parameter
322  else if (to_params_index_.count(name)) {
323  size_t idx = to_params_index_.at(name);
324  SLIC_ERROR_ROOT_IF(idx >= param_fields.size(),
325  axom::fmt::format("Parameter field index {} out of bounds (size={}) for field '{}'", idx,
326  param_fields.size(), name));
327  fields_for_residual.push_back(param_fields[idx]);
328  } else {
329  SLIC_ERROR_ROOT(axom::fmt::format("Field or parameter '{}' (required by weak form '{}') not found in FieldStore",
330  name, weak_form_name));
331  }
332  }
333  return fields_for_residual;
334 }
335 
336 const std::shared_ptr<smith::Mesh>& FieldStore::getMesh() const { return mesh_; }
337 
338 std::shared_ptr<DirichletBoundaryConditions> FieldStore::getBoundaryConditions(const std::string& field_name) const
339 {
340  auto it = boundary_conditions_.find(field_name);
341  if (it != boundary_conditions_.end()) {
342  return it->second;
343  }
344  return nullptr;
345 }
346 
347 const std::shared_ptr<gretl::DataStore>& FieldStore::graph() const { return graph_; }
348 
349 const std::vector<std::pair<std::shared_ptr<TimeIntegrationRule>, FieldStore::TimeIntegrationMapping>>&
351 {
352  return time_integration_rules_;
353 }
354 
355 void FieldStore::setField(size_t index, FieldState updated_field) { states_[index] = updated_field; }
356 
357 void FieldStore::markWeakFormInternal(const std::string& weak_form_name)
358 {
359  internal_weak_forms_.insert(weak_form_name);
360 }
361 
362 void FieldStore::addWeakFormReaction(std::string weak_form_name, std::string field_name)
363 {
364  for (auto& kv : weak_form_to_test_field_) {
365  if (kv.first == weak_form_name) {
366  kv.second = field_name;
367  return;
368  }
369  }
370  weak_form_to_test_field_.push_back({weak_form_name, field_name});
371 }
372 
373 std::string FieldStore::getWeakFormReaction(const std::string& weak_form_name) const
374 {
375  for (const auto& kv : weak_form_to_test_field_) {
376  if (kv.first == weak_form_name) {
377  return kv.second;
378  }
379  }
380  SLIC_ERROR("Reaction field not found for weak form " << weak_form_name);
381  return "";
382 }
383 
384 const std::vector<FieldState>& FieldStore::getParameterFields() const { return params_; }
385 
386 const std::vector<FieldState>& FieldStore::getStateFields() const { return states_; }
387 
388 std::vector<FieldState> FieldStore::getOutputFieldStates() const
389 {
390  std::vector<FieldState> output;
391  std::set<std::string> public_static_fields;
392  for (const auto& [rule, mapping] : time_integration_rules_) {
393  if (mapping.history_name.empty() && mapping.dot_name.empty() && mapping.ddot_name.empty()) {
394  public_static_fields.insert(mapping.primary_name);
395  }
396  }
397  for (size_t i = 0; i < states_.size(); ++i) {
398  if (!is_solve_state_[i] || public_static_fields.count(states_[i].get()->name()) > 0) {
399  output.push_back(states_[i]);
400  }
401  }
402  return output;
403 }
404 
405 std::vector<ReactionInfo> FieldStore::getReactionInfos() const
406 {
407  std::vector<ReactionInfo> infos;
408  for (const auto& kv : weak_form_to_test_field_) {
409  const std::string& weak_form_name = kv.first;
410  if (internal_weak_forms_.count(weak_form_name)) {
411  continue;
412  }
413  const std::string& field_name = kv.second;
414  infos.push_back({weak_form_name, &getField(field_name).get()->space()});
415  }
416  return infos;
417 }
418 
419 } // namespace smith
Contains DirichletBoundaryConditions class for interaction with the differentiable solve interfaces.
Accelerator functionality.
Definition: smith.cpp:36
std::shared_ptr< FiniteElementState > FEFieldPtr
typedef
Definition: field_state.hpp:20
gretl::State< FEFieldPtr, FEDualPtr > FieldState
typedef
Definition: field_state.hpp:22
Methods for solving systems of equations as given by WeakForms. Tracks these operations on the gretl ...
Mapping between primary and history/derivative fields for time integration.
void printMap()
Print the internal field maps for debugging.
Definition: field_store.cpp:56
std::vector< FieldState > getOutputFieldStates() const
Get the list of physical, non-solve state fields suitable for output.
const std::vector< FieldState > & getParameterFields() const
Get the list of all parameter fields.
const std::vector< std::pair< std::shared_ptr< TimeIntegrationRule >, TimeIntegrationMapping > > & getTimeIntegrationRules() const
Get all registered time integration rules and their mappings.
bool hasField(const std::string &field_name) const
Check whether a field exists.
const std::vector< FieldState > & getAllFields() const
Get all state fields stored in the FieldStore (parameter fields are not included; see getParameterFields()).
std::vector< std::vector< size_t > > indexMap(const std::vector< std::string > &residual_names) const
Generate an index map for the residuals.
Definition: field_store.cpp:67
std::string prefix(const std::string &base) const
Apply this store's namespace prefix to a base name.
Definition: field_store.cpp:21
std::string getWeakFormReaction(const std::string &weak_form_name) const
Get the name of the reaction (test) field for a weak form.
const std::shared_ptr< smith::Mesh > & getMesh() const
Get associated mesh shared by all registered fields.
const std::vector< FieldState > & getStateFields() const
Get the list of all state fields.
void setField(const std::string &field_name, FieldState updated_field)
Update a field in the store by name.
void addWeakFormArg(std::string weak_form_name, std::string argument_name, size_t argument_index)
Register an argument to a weak form.
Definition: field_store.cpp:44
void addWeakFormReaction(std::string weak_form_name, std::string field_name)
Register the reaction (test) field for a weak form.
FieldState getParameter(const std::string &param_name) const
Get a parameter field by name.
size_t getFieldIndex(const std::string &field_name) const
Get the internal index of a field by name.
std::vector< FieldState > getStates(const std::string &weak_form_name) const
Get the state fields associated with a weak form.
FieldState getShapeDisp() const
Get the shape displacement field.
std::vector< const BoundaryConditionManager * > getBoundaryConditionManagers(const std::vector< std::string > &weak_form_names) const
Get the boundary condition managers for the given weak forms, one per residual row.
void addWeakFormUnknownArg(std::string weak_form_name, std::string argument_name, size_t argument_index)
Register an argument to a weak form as an unknown.
Definition: field_store.cpp:34
std::vector< FieldState > getStatesFromVectors(const std::string &weak_form_name, const std::vector< FieldState > &state_fields, const std::vector< FieldState > &param_fields) const
Extract state fields for a weak form from provided state and parameter vectors.
std::shared_ptr< DirichletBoundaryConditions > getBoundaryConditions(const std::string &field_name) const
Get the boundary conditions for a given field name.
std::vector< const BoundaryConditionManager * > getBoundaryConditionManagersForFields(const std::vector< std::string > &field_names) const
Get ordered boundary condition managers corresponding to an ordered list of fields.
FieldState getField(const std::string &field_name) const
Get a FieldState by name.
const std::shared_ptr< gretl::DataStore > & graph() const
Get the associated data store graph.
std::vector< ReactionInfo > getReactionInfos() const
Get information about reaction fields.
FieldStore(std::shared_ptr< Mesh > mesh, size_t storage_size=50, std::string prepend_name="")
Construct a new FieldStore object.
Definition: field_store.cpp:14
void markWeakFormInternal(const std::string &weak_form_name)
Mark a weak form as internal so it is excluded from getReactionInfos().