"""Configuration system for CTG wave equation solver.
This module provides Pydantic-based configuration classes that support:
- Loading from YAML files
- Automatic validation of types and values
- Resolution of string paths to callable functions
- Handling of arbitrary types (MPI communicators, callables)
Example:
Load configuration from YAML file::
from pathlib import Path
from ctg.config import load_config
config = load_config(Path("config.yaml"))
print(config.numerics.n_cells_space) # 100
Create configuration with string references to functions::
config = AppConfig(
physics={
"initial_data_u": "data.data_functions_pwe:initial_u",
"rhs_0": "data.data_functions_pwe:rhs_0"
}
)
"""
from pathlib import Path
from pydantic import BaseModel, Field, ConfigDict, field_validator
import yaml
from typing import Callable, Any, Union
from importlib import import_module
from functools import partial
import numpy as np
from mpi4py import MPI
def _resolve_callable(v: Union[str, dict, Callable]) -> Callable:
    """Resolve a callable from a callable, an import-path string, or a dict spec.

    Supported input formats:

    - Direct callable: returned as-is.
    - String path: ``"module:function"`` or ``"module.submodule.function"``.
    - Dict spec: ``{"path": "module:function", "params": {...}}``. The keys
      ``"target"`` and ``"func"`` are accepted as aliases for ``"path"``
      (checked in that order, first non-empty value wins).

    Args:
        v: The callable, string path, or dict spec to resolve.

    Returns:
        Callable: The resolved object. When a dict spec carries non-empty
        ``"params"``, a ``functools.partial`` with those params bound as
        keyword arguments is returned instead.

    Raises:
        ValueError: If a dict spec contains no path, or if the path has no
            module component (e.g. a bare ``"sqrt"``).
        TypeError: If ``v`` is not a callable, string, or dict, or if the
            resolved attribute is not callable.
        ModuleNotFoundError: If the module cannot be imported.
        AttributeError: If the attribute does not exist in the module.

    Example:
        >>> _resolve_callable("math:sqrt")(4.0)
        2.0
        >>> to_int = _resolve_callable(
        ...     {"path": "builtins:int", "params": {"base": 2}}
        ... )
        >>> to_int("101")
        5
    """
    if callable(v):
        return v

    if isinstance(v, str):
        path = v
        params = {}
    elif isinstance(v, dict):
        path = v.get("path") or v.get("target") or v.get("func")
        if not path:
            raise ValueError("dict spec must include 'path'")
        # A missing or null "params" entry means "no bound arguments".
        params = v.get("params") or {}
    else:
        raise TypeError("Expected callable | str | dict with 'path'")

    # "module:function" splits on the first colon; otherwise the last dot
    # separates the module path from the attribute name.
    if ":" in path:
        mod, _, name = path.partition(":")
    else:
        mod, _, name = path.rpartition(".")

    if not mod:
        # Without this guard import_module("") fails with the unhelpful
        # "Empty module name"; keep the ValueError type but say what is wrong.
        raise ValueError(
            f"Cannot resolve {path!r}: expected 'module:function' or "
            "'module.function' (module part is missing)"
        )

    obj = getattr(import_module(mod), name)
    if not callable(obj):
        raise TypeError(f"{path} is not callable")
    return partial(obj, **params) if params else obj
class physicsCfg(BaseModel):
    """Physics section of the configuration for the wave equation problem.

    Holds the callables that define the problem data (exact solutions,
    initial and boundary data, right-hand sides) together with the simulated
    time interval. Per the module's convention, each callable should accept
    ``X`` with shape ``(n_points, 2)`` whose columns are ``[t, x]``; every
    default returns ``np.zeros(X.shape[0])``.

    Attributes:
        exact_sol_u: Exact solution for displacement u. Default returns zeros.
        exact_sol_v: Exact solution for velocity v. Default returns zeros.
        initial_data_u: Initial condition for displacement u(x, t=0).
            Default returns zeros.
        initial_data_v: Initial condition for velocity v(x, t=0).
            Default returns zeros.
        boundary_data_u: Boundary condition for displacement u.
            Default returns zeros.
        boundary_data_v: Boundary condition for velocity v.
            Default returns zeros.
        boundary_D: Dirichlet boundary condition. Default returns zeros.
        rhs_0: Right-hand side forcing term for the first equation.
            Default returns zeros.
        rhs_1: Right-hand side forcing term for the second equation.
            Default returns zeros.
        start_time: Start time of the simulation. Default: 0.0.
        end_time: End time of the simulation; must be greater than 0.
            Default: 1.0.

    Note:
        Every callable field is passed through ``_resolve_callable`` before
        validation, so it may be given as:

        - a direct callable: ``lambda X: np.sin(X[:, 1])``
        - a string path: ``"module.submodule:function_name"``
        - a dict with params: ``{"path": "module:func", "params": {"a": 1}}``

    Example:
        >>> physics = physicsCfg(
        ...     initial_data_u="data.data_functions_pwe:initial_u",
        ...     rhs_0=lambda X: np.zeros(X.shape[0]),
        ...     start_time=0.0,
        ...     end_time=2.0,
        ... )
    """

    # Callables are not a type Pydantic can validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    exact_sol_u: Callable[..., Any] = Field(
        default=lambda pts: np.zeros(pts.shape[0]),
        description="Exact solution for displacement u",
    )
    exact_sol_v: Callable[..., Any] = Field(
        default=lambda pts: np.zeros(pts.shape[0]),
        description="Exact solution for velocity v",
    )
    initial_data_u: Callable[..., Any] = Field(
        default=lambda pts: np.zeros(pts.shape[0]),
        description="Initial condition for displacement u(x, t=0)",
    )
    initial_data_v: Callable[..., Any] = Field(
        default=lambda pts: np.zeros(pts.shape[0]),
        description="Initial condition for velocity v(x, t=0)",
    )
    boundary_data_u: Callable[..., Any] = Field(
        default=lambda pts: np.zeros(pts.shape[0]),
        description="Boundary condition for displacement u",
    )
    boundary_data_v: Callable[..., Any] = Field(
        default=lambda pts: np.zeros(pts.shape[0]),
        description="Boundary condition for velocity v",
    )
    boundary_D: Callable[..., Any] = Field(
        default=lambda pts: np.zeros(pts.shape[0]),
        description="Dirichlet boundary condition",
    )
    rhs_0: Callable[..., Any] = Field(
        default=lambda pts: np.zeros(pts.shape[0]),
        description="Right-hand side forcing term for first equation",
    )
    rhs_1: Callable[..., Any] = Field(
        default=lambda pts: np.zeros(pts.shape[0]),
        description="Right-hand side forcing term for second equation",
    )
    start_time: float = Field(
        default=0.0,
        description="Start time for the physics simulation",
    )
    end_time: float = Field(
        default=1.0,
        gt=0.0,
        description="End time for the physics simulation",
    )

    @field_validator(
        "initial_data_u",
        "initial_data_v",
        "boundary_data_u",
        "boundary_data_v",
        "boundary_D",
        "rhs_0",
        "rhs_1",
        "exact_sol_u",
        "exact_sol_v",
        mode="before",
    )
    @classmethod
    def resolve_callables(cls, value):
        """Turn a callable, string path, or dict spec into a callable."""
        resolved = _resolve_callable(value)
        return resolved
class numericsCfg(BaseModel):
    """Numerics section: discretization parameters, seed, and communicator.

    Attributes:
        seed: Random seed for reproducibility. Default: 0.
        comm: MPI communicator. Default: ``MPI.COMM_SELF``. A string such as
            ``"module:attribute"`` or ``"module.attribute"`` is imported and
            replaced by the named attribute before validation.
        n_cells_space: Number of spatial cells (>= 1). Default: 100.
        order_x: Spatial FE polynomial degree (>= 1). Default: 1.
        t_slab_size: Time slab size (> 0). Default: 0.01.
        order_t: Temporal FE polynomial degree (>= 1). Default: 1.
        verbose: Print output from CTGSolver during run. Default: False.

    Example:
        >>> numerics = numericsCfg(
        ...     n_cells_space=200,
        ...     t_slab_size=0.05,
        ...     comm="data.data_pwe_functions:comm",
        ... )
    """

    # MPI.Comm is not a type Pydantic knows how to validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    seed: int = Field(
        default=0,
        description="Random seed for reproducibility",
    )
    comm: MPI.Comm = Field(
        default=MPI.COMM_SELF,
        description="MPI communicator",
    )
    n_cells_space: int = Field(
        default=100,
        ge=1,
        description="Number of spatial cells",
    )
    order_x: int = Field(
        default=1,
        ge=1,
        description="Spatial FE polynomial degree",
    )
    t_slab_size: float = Field(
        default=0.01,
        gt=0.0,
        description="Time slab size",
    )
    order_t: int = Field(
        default=1,
        ge=1,
        description="Temporal FE polynomial degree",
    )
    verbose: bool = Field(
        default=False,
        description="Print output CTGSolver during run",
    )

    @field_validator("comm", mode="before")
    @classmethod
    def resolve_comm(cls, value):
        """Import the communicator named by a string path; pass others through."""
        if not isinstance(value, str):
            return value

        # "module:attribute" splits on the first colon; otherwise the last
        # dot separates the module path from the attribute name.
        if ":" in value:
            module_path, _, attr = value.partition(":")
        else:
            module_path, _, attr = value.rpartition(".")
        module = import_module(module_path)
        return getattr(module, attr)
class postCfg(BaseModel):
    """Post-processing section of the configuration.

    Attributes:
        dir_save: Directory where outputs will be saved. Required; it has
            no default value.
    """

    dir_save: str = Field(
        ...,
        description="Directory where outputs will be saved",
    )
class AppConfig(BaseModel):
    """Main application configuration container.

    Aggregates all configuration sections for the CTG wave equation solver:
    physics setup, numerical parameters, and post-processing settings.

    Attributes:
        physics: Physics configuration (initial/boundary conditions, forcing).
            Defaults to ``physicsCfg()``.
        numerics: Numerical discretization configuration.
            Defaults to ``numericsCfg()``.
        post: Post-processing configuration. Required, because
            ``postCfg.dir_save`` has no default.

    Example::

        # Minimal configuration: only the output directory is mandatory
        config = AppConfig(post={"dir_save": "results"})

        # Load from YAML
        config = load_config(Path("config.yaml"))

        # Create programmatically
        config = AppConfig(
            physics={"initial_data_u": "data.funcs:my_initial_u", "end_time": 2.0},
            numerics={"n_cells_space": 200},
            post={"dir_save": "results"},
        )
    """

    # Reject unknown top-level sections, make instances immutable, and run
    # validation on default values as well.
    model_config = ConfigDict(extra="forbid", frozen=True, validate_default=True)

    physics: physicsCfg = Field(
        default_factory=physicsCfg,
        description="Physics configuration",
    )
    numerics: numericsCfg = Field(
        default_factory=numericsCfg,
        description="Numerical discretization configuration",
    )
    # postCfg cannot be built without arguments (dir_save is required), so a
    # default_factory=postCfg could never succeed. Declaring the field as
    # required yields a direct "field required" error when it is omitted.
    post: postCfg = Field(
        ...,
        description="Post-processing configuration",
    )
def load_config(source: Union[Path, str]) -> AppConfig:
    """Load configuration from a YAML file path or YAML string content.

    A ``Path`` is read from disk as UTF-8; a ``str`` is always parsed as YAML
    content (it is never interpreted as a file name). The parsed mapping is
    passed to ``AppConfig``, which validates it and resolves string
    references to callables.

    Args:
        source: Either a ``Path`` pointing to a YAML configuration file, or a
            string containing YAML content.

    Returns:
        AppConfig: Validated configuration object.

    Raises:
        FileNotFoundError: If source is a Path and the file doesn't exist.
        yaml.YAMLError: If the YAML content is malformed.
        TypeError: If the top level of the YAML document is not a mapping
            (for example when a file name is passed as ``str``).
        pydantic.ValidationError: If the configuration doesn't match the
            schema. This includes an empty document, since the ``post``
            section is mandatory.

    Example:
        >>> from pathlib import Path
        >>> # Load from file path
        >>> config = load_config(Path("configs/wave_eq.yaml"))
        >>> # Load from YAML string
        >>> yaml_str = '''
        ... physics:
        ...   initial_data_u: "data.data_functions_pwe:initial_u"
        ... numerics:
        ...   n_cells_space: 100
        ... post:
        ...   dir_save: "results"
        ... '''
        >>> config = load_config(yaml_str)
        >>> print(config.numerics.n_cells_space)
        100
    """
    if isinstance(source, Path):
        yaml_text = source.read_text(encoding="utf-8")
    else:
        yaml_text = source

    data = yaml.safe_load(yaml_text)

    # An empty document parses to None; treat it as "no overrides" so schema
    # validation reports what is missing instead of `**None` failing.
    if data is None:
        data = {}
    if not isinstance(data, dict):
        raise TypeError(
            "YAML configuration must be a mapping at the top level, got "
            f"{type(data).__name__}; to load from a file, pass a pathlib.Path"
        )
    return AppConfig(**data)