24 template <
int spatial_dim,
typename OutputSpace,
typename inputs = Parameters<>,
25 typename input_indices = std::make_
integer_sequence<
int, inputs::n>>
35 template <
int spatial_dim,
typename OutputSpace,
typename... InputSpaces,
int... input_indices>
37 std::integer_sequence<int, input_indices...>> :
public WeakForm {
39 using SpacesT = std::vector<const mfem::ParFiniteElementSpace*>;
52 const mfem::ParFiniteElementSpace& output_mfem_space,
const SpacesT& input_mfem_spaces)
53 :
WeakForm(physics_name), mesh_(mesh)
55 std::array<
const mfem::ParFiniteElementSpace*,
sizeof...(InputSpaces)> trial_spaces;
56 std::array<
const mfem::ParFiniteElementSpace*,
sizeof...(InputSpaces) + 1> vector_residual_trial_spaces{
60 sizeof...(InputSpaces) != input_mfem_spaces.size(),
61 std::format(
"{} parameter spaces given in the template argument but {} input mfem spaces were supplied.",
62 sizeof...(InputSpaces), input_mfem_spaces.size()));
65 validateSpace<OutputSpace>(output_mfem_space,
"output");
68 if constexpr (
sizeof...(InputSpaces) > 0) {
70 validateInputSpaces<0>(input_mfem_spaces);
72 for_constexpr<
sizeof...(InputSpaces)>([&](
auto i) { trial_spaces[i] = input_mfem_spaces[i]; });
74 [&](
auto i) { vector_residual_trial_spaces[i + 1] = input_mfem_spaces[i]; });
77 const auto& shape_disp_space = mesh->shapeDisplacementSpace();
79 weak_form_ = std::make_unique<ShapeAwareFunctional<
ShapeDispSpace, OutputSpace(InputSpaces...)>>(
80 &shape_disp_space, &output_mfem_space, trial_spaces);
82 v_dot_weak_form_residual_ =
83 std::make_unique<ShapeAwareFunctional<
ShapeDispSpace, double(OutputSpace, InputSpaces...)>>(
84 &shape_disp_space, vector_residual_trial_spaces);
110 template <
int... active_parameters,
typename BodyIntegralType>
114 mesh_->domain(body_name));
116 v_dot_weak_form_residual_->AddDomainIntegral(
118 [integrand](
double time,
auto X,
auto V,
auto... inputs) {
119 auto orig_tuple = integrand(time, X, inputs...);
120 return smith::inner(get<VALUE>(V), get<VALUE>(orig_tuple)) +
121 smith::inner(get<DERIVATIVE>(V), get<DERIVATIVE>(orig_tuple));
123 mesh_->domain(body_name));
127 template <
typename BodyForceType>
130 addBodyIntegral(
DependsOn<>{}, body_name, body_integral);
153 template <
int... active_parameters,
typename BodyLoadType>
156 addBodyIntegral(depends_on, body_name, [load_function](
double t,
auto X,
auto... inputs) {
162 template <
int... active_parameters,
typename BodyLoadType>
165 return addBodySource(
DependsOn<>{}, body_name, load_function);
191 template <
int... active_parameters,
typename BoundaryIntegrandType>
195 mesh_->domain(boundary_name));
197 v_dot_weak_form_residual_->AddBoundaryIntegral(
199 [integrand](
double t,
auto X,
auto V,
auto... params) {
200 auto orig_surface_flux = integrand(t, X, params...);
203 mesh_->domain(boundary_name));
207 template <
typename BoundaryIntegrandType>
210 addBoundaryIntegral(
DependsOn<>{}, boundary_name, integrand);
231 template <
int... active_parameters,
typename InteriorIntegrandType>
233 InteriorIntegrandType integrand)
236 mesh_->domain(interior_name));
238 v_dot_weak_form_residual_->AddInteriorFaceIntegral(
240 [integrand](
double t,
auto X,
auto V,
auto... params) {
242 auto orig_surface_flux = integrand(t, X, params...);
243 auto [flux_pos, flux_neg] = orig_surface_flux;
246 mesh_->domain(interior_name));
250 template <
typename InteriorIntegrandType>
253 addInteriorBoundaryIntegral(
DependsOn<>{}, interior_name, integrand);
277 template <
int... active_parameters,
typename BoundaryFluxType>
279 BoundaryFluxType flux_function)
281 addBoundaryIntegral(depends_on, boundary_name, [flux_function](
double t,
auto X,
auto... inputs) {
282 auto n =
cross(get<DERIVATIVE>(X));
283 return -flux_function(t, get<VALUE>(X),
normalize(n), get<VALUE>(inputs)...);
288 template <
typename BoundaryFluxType>
291 addBoundaryFlux(
DependsOn<>{}, boundary_name, integrand);
296 [[maybe_unused]]
const std::vector<ConstQuadratureFieldPtr>& quad_fields = {})
const override
298 validateFields(fields,
"residual");
299 dt_ = time_info.
dt();
300 cycle_ = time_info.
cycle();
301 auto ret = (*weak_form_)(time_info.
time(), *shape_disp, *fields[input_indices]...);
308 const std::vector<double>& jacobian_weights,
309 [[maybe_unused]]
const std::vector<ConstQuadratureFieldPtr>& quad_fields = {})
const override
311 validateFields(fields,
"jacobian");
312 dt_ = time_info.
dt();
313 cycle_ = time_info.
cycle();
315 std::unique_ptr<mfem::HypreParMatrix> J;
317 auto addToJ = [&J](
double factor, std::unique_ptr<mfem::HypreParMatrix> jac_contrib) {
319 SLIC_ERROR_IF(J->N() != jac_contrib->N(),
320 "Multiple nonzero jacobian weights are being used on inconsistently sized input arguments.");
321 SLIC_ERROR_IF(J->M() != jac_contrib->M(),
322 "Multiple nonzero jacobian weights are being used on inconsistently sized input arguments.");
323 J->Add(factor, *jac_contrib);
325 J.reset(jac_contrib.release());
326 if (factor != 1.0) (*J) *= factor;
330 auto jacs = jacobianFunctions(std::make_integer_sequence<
int,
sizeof...(input_indices)>{}, time_info.
time(),
333 for (
size_t input_col = 0; input_col < jacobian_weights.size(); ++input_col) {
334 if (jacobian_weights[input_col] != 0.0) {
335 auto K = smith::get<DERIVATIVE>(jacs[input_col](time_info.
time(), shape_disp, fields));
336 addToJ(jacobian_weights[input_col], assemble(K));
345 [[maybe_unused]]
const std::vector<ConstQuadratureFieldPtr>& quad_fields,
346 [[maybe_unused]]
ConstFieldPtr v_shape_disp,
const std::vector<ConstFieldPtr>& v_fields,
347 [[maybe_unused]]
const std::vector<ConstQuadratureFieldPtr>& v_quad_fields,
350 validateFields(fields,
"jvp");
351 SLIC_ERROR_IF(v_fields.size() != fields.size(),
352 "Invalid number of field sensitivities relative to the number of fields");
354 dt_ = time_info.
dt();
355 cycle_ = time_info.
cycle();
357 auto jacs = jacobianFunctions(std::make_integer_sequence<
int,
sizeof...(input_indices)>{}, time_info.
time(),
362 for (
size_t input_col = 0; input_col < fields.size(); ++input_col) {
363 if (v_fields[input_col] !=
nullptr) {
364 auto K = smith::get<DERIVATIVE>(jacs[input_col](time_info.
time(), shape_disp, fields));
365 K.AddMult(*v_fields[input_col], *jvp_reaction);
372 [[maybe_unused]]
const std::vector<ConstQuadratureFieldPtr>& quad_fields,
ConstFieldPtr v_field,
373 DualFieldPtr vjp_shape_disp_sensitivity,
const std::vector<DualFieldPtr>& vjp_sensitivities,
374 [[maybe_unused]]
const std::vector<QuadratureFieldPtr>& vjp_quad_field_sensitivities)
const override
376 validateFields(fields,
"vjp");
377 SLIC_ERROR_IF(vjp_sensitivities.size() != fields.size(),
378 "Invalid number of field sensitivities relative to the number of fields");
380 dt_ = time_info.
dt();
381 cycle_ = time_info.
cycle();
383 auto vecJacs = vectorJacobianFunctions(std::make_integer_sequence<
int,
sizeof...(input_indices)>{},
384 time_info.
time(), shape_disp, v_field, fields);
386 auto shape_vjp = smith::get<DERIVATIVE>((*v_dot_weak_form_residual_)(
388 auto shape_vjp_vector = assemble(shape_vjp);
389 *vjp_shape_disp_sensitivity += *shape_vjp_vector;
392 for (
size_t input_col = 0; input_col < fields.size(); ++input_col) {
393 if (vjp_sensitivities[input_col] !=
nullptr) {
394 auto vec_jac = smith::get<DERIVATIVE>(vecJacs[input_col](time_info.
time(), shape_disp, v_field, fields));
395 auto vec_jac_mfem_vector = assemble(vec_jac);
396 *vjp_sensitivities[input_col] += *vec_jac_mfem_vector;
410 return *v_dot_weak_form_residual_;
418 if constexpr (I <
sizeof...(InputSpaces)) {
420 validateSpace<Space>(*input_mfem_spaces[I], axom::fmt::format(
"input[{}]", I));
421 validateInputSpaces<I + 1>(input_mfem_spaces);
429 if constexpr (I <
sizeof...(InputSpaces)) {
431 validateSpace<Space>(fields[I]->
space(),
432 axom::fmt::format(
"{}(): field[{}] ('{}')", method_name, I, fields[I]->name()));
433 validateFieldsRecursive<I + 1>(fields, method_name);
438 void validateFields(
const std::vector<ConstFieldPtr>& fields,
const std::string& method_name)
const
440 SLIC_ERROR_ROOT_IF(fields.size() !=
sizeof...(InputSpaces),
441 axom::fmt::format(
"{}(): fields.size()={} but weak form expects {} InputSpaces", method_name,
442 fields.size(),
sizeof...(InputSpaces)));
444 if constexpr (
sizeof...(InputSpaces) > 0) {
445 validateFieldsRecursive<0>(fields, method_name);
450 template <
typename Space>
451 static void validateSpace(
const mfem::ParFiniteElementSpace& mfem_space,
const std::string& space_name)
453 const auto* fec = mfem_space.FEColl();
457 if constexpr (Space::family == Family::H1) {
458 std::string fec_name = fec->Name();
460 (fec_name.find(
"H1") != std::string::npos || fec_name.find(
"ND_") != std::string::npos ||
461 fec_name.find(
"Linear") != std::string::npos);
462 SLIC_ERROR_ROOT_IF(!is_h1, axom::fmt::format(
"Space '{}': Template specifies H1 family but mfem space uses '{}'",
463 space_name, fec_name));
464 }
else if constexpr (Space::family == Family::L2) {
465 std::string fec_name = fec->Name();
466 bool is_l2 = (fec_name.find(
"L2") != std::string::npos || fec_name.find(
"DG") != std::string::npos ||
467 fec_name.find(
"Const") != std::string::npos);
469 !is_l2, axom::fmt::format(
"Space '{}': Template specifies L2/DG family but mfem space uses '{}'", space_name,
474 SLIC_ERROR_ROOT_IF(fec->GetOrder() != Space::order,
475 axom::fmt::format(
"Space '{}': Template specifies order {} but mfem space has order {}",
476 space_name, Space::order, fec->GetOrder()));
480 mfem_space.GetVDim() != Space::components,
481 axom::fmt::format(
"Space '{}': Template specifies {} components but mfem space has {} components (VDim)",
482 space_name, Space::components, mfem_space.GetVDim()));
488 const std::vector<ConstFieldPtr>& fs)
const
490 using JacFuncType = std::function<decltype((*weak_form_)(
DifferentiateWRT<1>{}, time, *shape_disp, *fs[i]...))(
492 return std::array<JacFuncType,
sizeof...(i)>{
493 [
this](
double _time,
ConstFieldPtr _shape_disp,
const std::vector<ConstFieldPtr>& _fs) {
501 const std::vector<ConstFieldPtr>& fs)
const
504 std::function<decltype((*v_dot_weak_form_residual_)(
DifferentiateWRT<1>{}, time, *shape_disp, *v, *fs[i]...))(
506 return std::array<GradFuncType,
sizeof...(i)>{
516 mutable size_t cycle_ = 0;
530 inline std::vector<const mfem::ParFiniteElementSpace*>
getSpaces(
const std::vector<smith::FiniteElementState>& states)
532 std::vector<const mfem::ParFiniteElementSpace*>
spaces;
533 for (
auto& f : states) {
534 spaces.push_back(&f.space());
Class for encapsulating the dual vector space of a finite element space (i.e. the space of linear for...
This contains a class that represents the dual of a finite element vector space, i....
This file contains the declaration of structure that manages the MFEM objects that make up the state ...
Smith mesh class which assists in constructing the appropriate parallel mfem meshes and registering a...
Accelerator functionality.
std::vector< const mfem::ParFiniteElementSpace * > getSpaces(const std::vector< smith::FiniteElementState > &states)
Helper function to construct vector of spaces from an existing vector of FiniteElementState.
std::vector< const mfem::ParFiniteElementSpace * > spaces(const std::vector< FieldState > &states, const std::vector< FieldState > &params={})
Get the spaces from the primal fields of a vector of field states.
SMITH_HOST_DEVICE auto cross(const tensor< T, 3, 2 > &A)
compute the cross product of the columns of A: A(:,1) x A(:,2)
SMITH_HOST_DEVICE auto max(dual< gradient_type > a, double b)
Implementation of max for dual numbers.
mfem::future::tuple< T... > tuple
Expose MFEM tuple in the Smith namespace.
constexpr SMITH_HOST_DEVICE auto inner(const dual< S > &A, const dual< T > &B)
mfem::future::tuple_element< I, T > tuple_element
Alias for the MFEM tuple element trait.
mfem::ParFiniteElementSpace & space(FieldState field)
Get the space from the primal field of a field states.
SMITH_HOST_DEVICE auto normalize(const tensor< T, n... > &A)
Normalizes the tensor. Each element is divided by the Frobenius norm of the tensor,...
FiniteElementState const * ConstFieldPtr
using
Wrapper of smith::Functional for evaluating integrals and derivatives of quantities with shape displa...
Compile-time alias for a dimension.
a struct that is used in the physics modules to clarify which template arguments are user-controlled ...
struct storing time and timestep information
double dt() const
accessor for dt
size_t cycle() const
accessor for cycle
double time() const
accessor for the current time
A sentinel struct for eliding no-op tensor operations.