From 2116575a6b13e23d97d6352d2143645e9ea29f7a Mon Sep 17 00:00:00 2001 From: dbochkov-flexcompute Date: Fri, 4 Apr 2025 02:44:55 -0700 Subject: [PATCH] surface monitors --- docs/notebooks | 2 +- tidy3d/__init__.py | 15 + tidy3d/components/data/data_array.py | 74 +++- tidy3d/components/data/dataset.py | 73 +++- tidy3d/components/data/monitor_data.py | 200 ++++++++- tidy3d/components/data/unstructured/base.py | 302 +++++++------ .../components/data/unstructured/surface.py | 399 ++++++++++++++++++ .../data/unstructured/tetrahedral.py | 56 --- tidy3d/components/monitor.py | 155 +++++++ tidy3d/components/types.py | 1 + tidy3d/components/viz.py | 22 + 11 files changed, 1111 insertions(+), 188 deletions(-) create mode 100644 tidy3d/components/data/unstructured/surface.py diff --git a/docs/notebooks b/docs/notebooks index 6767056742..a1815f18b0 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit 67670567425479be746839b161de7ca4a9a39caa +Subproject commit a1815f18b0e5cd30018ea56d5ba8802376fde8b8 diff --git a/tidy3d/__init__.py b/tidy3d/__init__.py index baa77266e7..63e6930f17 100644 --- a/tidy3d/__init__.py +++ b/tidy3d/__init__.py @@ -124,6 +124,10 @@ FluxTimeDataArray, HeatDataArray, IndexedDataArray, + IndexedFieldDataArray, + IndexedFieldTimeDataArray, + IndexedFreqDataArray, + IndexedTimeDataArray, IndexedVoltageDataArray, ModeAmpsDataArray, ModeIndexDataArray, @@ -158,6 +162,7 @@ PermittivityData, ) from .components.data.sim_data import DATA_TYPE_MAP, SimulationData +from .components.data.unstructured.surface import TriangularSurfaceDataset from .components.data.utils import ( TetrahedralGridDataset, TriangularGridDataset, @@ -284,6 +289,8 @@ ModeSolverMonitor, Monitor, PermittivityMonitor, + SurfaceFieldMonitor, + SurfaceFieldTimeMonitor, ) from .components.parameter_perturbation import ( CustomChargePerturbation, @@ -625,9 +632,14 @@ def set_logging_level(level: str) -> None: "CellDataArray", "IndexedDataArray", "IndexedVoltageDataArray", + "IndexedFieldDataArray", + "IndexedFieldTimeDataArray", + "IndexedFreqDataArray", + "IndexedTimeDataArray", "SteadyVoltageDataArray", "TriangularGridDataset", "TetrahedralGridDataset", + "TriangularSurfaceDataset", "medium_from_nk", "SubpixelSpec", "Staircasing", @@ -676,4 +688,7 @@ def set_logging_level(level: str) -> None: "IsothermalSteadyChargeDCAnalysis", "ChargeToleranceSpec", "AntennaMetricsData", + "SurfaceFieldMonitor", + "SurfaceFieldTimeMonitor", + "TriangularSurfaceDataset", ] diff --git a/tidy3d/components/data/data_array.py b/tidy3d/components/data/data_array.py index 9cabd11796..281e3bd1a8 100644 --- a/tidy3d/components/data/data_array.py +++ b/tidy3d/components/data/data_array.py @@ -1253,6 +1253,66 @@ class SpatialVoltageDataArray(AbstractSpatialDataArray): _dims = ("x", "y", "z", "voltage") +class IndexedFieldDataArray(DataArray): + """Stores indexed values of vector fields in frequency domain. It is typically used + in conjuction with a ``PointDataArray`` to store point-associated vector data. + + Example + ------- + >>> indexed_array = IndexedFieldDataArray( + ... (1+1j) * np.random.random((4,3,1)), coords=dict(index=np.arange(4), axis=np.arange(3), f=[1e9]) + ... ) + """ + + __slots__ = () + _dims = ("index", "axis", "f") + + +class IndexedFieldTimeDataArray(DataArray): + """Stores indexed values of vector fields in time domain. It is typically used + in conjuction with a ``PointDataArray`` to store point-associated vector data. + + Example + ------- + >>> indexed_array = IndexedFieldDataArray( + ... (1+1j) * np.random.random((4,3,1)), coords=dict(index=np.arange(4), axis=np.arange(3), t=[0]) + ... ) + """ + + __slots__ = () + _dims = ("index", "axis", "t") + + +class IndexedFreqDataArray(DataArray): + """Stores indexed values of scalar fields in frequency domain. It is typically used + in conjuction with a ``PointDataArray`` to store point-associated vector data. + + Example + ------- + >>> indexed_array = IndexedFieldDataArray( + ... (1+1j) * np.random.random((4,1)), coords=dict(index=np.arange(4), f=[1e9]) + ... ) + """ + + __slots__ = () + _dims = ("index", "f") + + +class IndexedTimeDataArray(DataArray): + """Stores indexed values of scalar fields in time domain. It is typically used + in conjuction with a ``PointDataArray`` to store point-associated vector data. + + Example + ------- + >>> indexed_array = IndexedFieldDataArray( + ... (1+1j) * np.random.random((4,1)), coords=dict(index=np.arange(4), t=[0]) + ... ) + """ + + __slots__ = () + _dims = ("index", "t") + + DATA_ARRAY_TYPES = [ SpatialDataArray, ScalarFieldDataArray, @@ -1286,7 +1346,19 @@ class SpatialVoltageDataArray(AbstractSpatialDataArray): CellDataArray, IndexedDataArray, IndexedVoltageDataArray, + IndexedFieldDataArray, + IndexedFieldTimeDataArray, + IndexedFreqDataArray, + IndexedTimeDataArray, ] DATA_ARRAY_MAP = {data_array.__name__: data_array for data_array in DATA_ARRAY_TYPES} -IndexedDataArrayTypes = Union[IndexedDataArray, IndexedVoltageDataArray] +IndexedDataArrayTypes = Union[ + IndexedDataArray, + IndexedVoltageDataArray, + IndexedFieldDataArray, + IndexedFieldTimeDataArray, + IndexedFreqDataArray, + IndexedTimeDataArray, + PointDataArray, +] diff --git a/tidy3d/components/data/dataset.py b/tidy3d/components/data/dataset.py index 1e23ceda21..a05653ee2e 100644 --- a/tidy3d/components/data/dataset.py +++ b/tidy3d/components/data/dataset.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import numpy as np import pydantic.v1 as pd @@ -28,6 +28,7 @@ TimeDataArray, TriangleMeshDataArray, ) +from .unstructured.surface import TriangularSurfaceDataset DEFAULT_MAX_SAMPLES_PER_STEP = 10_000 DEFAULT_MAX_CELLS_PER_STEP = 10_000 @@ -394,6 +395,76 @@ class AuxFieldTimeDataset(AuxFieldDataset): ) +class ElectromagneticSurfaceFieldDataset(AbstractFieldDataset, ABC): + """Stores a collection of E and H fields with x, y, z components.""" + + E: Tuple[Optional[TriangularSurfaceDataset], Optional[TriangularSurfaceDataset]] = pd.Field( + (None, None), + title="E", + description="Spatial distribution of the electric field on the internal and external sides of the surface.", + ) + + H: Tuple[Optional[TriangularSurfaceDataset], Optional[TriangularSurfaceDataset]] = pd.Field( + (None, None), + title="H", + description="Spatial distribution of the magnetic field on the internal and external sides of the surface.", + ) + + normal: TriangularSurfaceDataset = pd.Field( + None, + title="Surface Normal", + description="Spatial distribution of the surface normal.", + ) + + @property + def field_components(self) -> Dict[str, DataArray]: + """Maps the field components to their associated data.""" + fields = { + "E": self.E, + "H": self.H, + } + return {field_name: field for field_name, field in fields.items() if field is not None} + + @property + def current_density(self) -> ElectromagneticSurfaceFieldDataset: + """Surface current density.""" + + h_diff = 0 + template = None + # we assume that is data is None it means field is zero on that side (e.g. PEC) + if self.H[0] is not None: + h_diff += self.H[0].values + template = self.H[0] + if self.H[1] is not None: + h_diff -= self.H[1].values + template = self.H[1] + + if template is None: + raise ValueError( + "Could not calculate current density: the dataset does not contain H field information." + ) + + return template.updated_copy(values=xr.cross(h_diff, self.normal.values, dim="axis")) + + @property + def grid_locations(self) -> Dict[str, str]: + """Maps field components to the string key of their grid locations on the yee lattice.""" + raise RuntimeError("Function 'grid_location' does not apply to surface monitors.") + + @property + def symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]: + """Maps field components to their (positive) symmetry eigenvalues.""" + + return dict( + Ex=lambda dim: -1 if (dim == 0) else +1, + Ey=lambda dim: -1 if (dim == 1) else +1, + Ez=lambda dim: -1 if (dim == 2) else +1, + Hx=lambda dim: +1 if (dim == 0) else -1, + Hy=lambda dim: +1 if (dim == 1) else -1, + Hz=lambda dim: +1 if (dim == 2) else -1, + ) + + class ModeSolverDataset(ElectromagneticFieldDataset): """Dataset storing scalar components of E and H fields as a function of freq. and mode_index. diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index 41e3bd2e2e..5783ec33a1 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -36,6 +36,8 @@ ModeSolverMonitor, MonitorType, PermittivityMonitor, + SurfaceFieldMonitor, + SurfaceFieldTimeMonitor, ) from ..source.base import Source from ..source.current import ( @@ -89,11 +91,13 @@ AuxFieldTimeDataset, Dataset, ElectromagneticFieldDataset, + ElectromagneticSurfaceFieldDataset, FieldDataset, FieldTimeDataset, ModeSolverDataset, PermittivityDataset, ) +from .unstructured.surface import TriangularSurfaceDataset Coords1D = ArrayFloat1D @@ -199,7 +203,13 @@ class AbstractFieldData(MonitorData, AbstractFieldDataset, ABC): """Collection of scalar fields with some symmetry properties.""" monitor: Union[ - FieldMonitor, FieldTimeMonitor, AuxFieldTimeMonitor, PermittivityMonitor, ModeMonitor + FieldMonitor, + FieldTimeMonitor, + AuxFieldTimeMonitor, + PermittivityMonitor, + ModeMonitor, + SurfaceFieldMonitor, + SurfaceFieldTimeMonitor, ] symmetry: Tuple[Symmetry, Symmetry, Symmetry] = pd.Field( @@ -1334,6 +1344,192 @@ class AuxFieldTimeData(AuxFieldTimeDataset, AbstractFieldData): _contains_monitor_fields = enforce_monitor_fields_present() +class AbstractSurfaceFieldData(MonitorData, AbstractFieldDataset, ABC): + """Collection of vector fields on a surfacewith some symmetry properties.""" + + monitor: Union[SurfaceFieldMonitor, SurfaceFieldTimeMonitor] + + symmetry: Tuple[Symmetry, Symmetry, Symmetry] = pd.Field( + (0, 0, 0), + title="Symmetry", + description="Symmetry eigenvalues of the original simulation in x, y, and z.", + ) + + symmetry_center: Coordinate = pd.Field( + None, + title="Symmetry Center", + description="Center of the symmetry planes of the original simulation in x, y, and z. " + "Required only if any of the ``symmetry`` field are non-zero.", + ) + + _require_sym_center = required_if_symmetry_present("symmetry_center") + + @property + def symmetry_expanded(self): + """Return the :class:`.AbstractSurfaceFieldData` with fields expanded based on symmetry. If + any symmetry is nonzero (i.e. expanded), the interpolation implicitly creates a copy of the + data array. However, if symmetry is not expanded, the returned array contains a view of + the data, not a copy. + + Returns + ------- + :class:`AbstractSurfaceFieldData` + A data object with the symmetry expanded fields. + """ + + if all(sym == 0 for sym in self.symmetry): + return self + + return self._updated(self._symmetry_update_dict) + + @property + def symmetry_expanded_copy(self) -> AbstractFieldData: + """Create a copy of the :class:`.AbstractSurfaceFieldData` with fields expanded based on symmetry. + + Returns + ------- + :class:`AbstractSurfaceFieldData` + A data object with the symmetry expanded fields. + """ + + if all(sym == 0 for sym in self.symmetry): + return self.copy() + + return self.copy(update=self._symmetry_update_dict) + + @property + def _symmetry_update_dict(self) -> Dict: + """Dictionary of data fields to create data with expanded symmetry.""" + + raise Tidy3dNotImplementedError("Surface monitors currently do not support symmetry.") + + +class ElectromagneticSurfaceFieldData( + AbstractSurfaceFieldData, ElectromagneticSurfaceFieldDataset, ABC +): + """Collection of electromagnetic fields on a surface.""" + + @property + def intensity(self) -> Tuple[TriangularSurfaceDataset, TriangularSurfaceDataset]: + """Return the sum of the squared absolute electric field components.""" + intensity = [None, None] + for ind in range(2): + if self.E[ind] is not None: + e_field = self.E[ind] + intensity[ind] = e_field.norm(dim="axis") ** 2 + return intensity + + @property + def poynting(self) -> Tuple[TriangularSurfaceDataset, TriangularSurfaceDataset]: + """Time-averaged Poynting vector for frequency-domain data.""" + + poynting = [None, None] + for ind in range(2): + if self.E[ind] is not None and self.H[ind] is not None: + e_field = self.E[ind] + h_field = self.H[ind] + + poynting[ind] = e_field.updated_copy( + values=0.5 + * np.real(xr.cross(e_field.values, np.conj(h_field.values), dim="axis")) + ) + + return poynting + + def _check_fields_stored(self, components: list[str]): + """Check that all requested field components are stored in the data.""" + missing_comps = [comp for comp in components if comp not in self.field_components.keys()] + if len(missing_comps) > 0: + raise DataError( + f"Field components {missing_comps} not included in this data object. Use " + "the 'fields' argument of a field monitor to select which components are stored." + ) + + +class SurfaceFieldData(ElectromagneticSurfaceFieldData): + """ + Data associated with a :class:`.SurfaceFieldMonitor`: E and H fields on a surface. + + Example + ------- + >>> from tidy3d import PointDataArray, IndexedFieldDataArray, TriangularSurfaceDataset, CellDataArray + >>> points = PointDataArray([[0, 0, 0], [0, 1, 0], [1, 1, 1]], dims=["index", "axis"]) + >>> cells = CellDataArray([[0, 1, 2]], dims=["cell_index", "vertex_index"]) + >>> values = PointDataArray([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dims=["index", "axis"]) + >>> field_values = IndexedFieldDataArray(np.ones((3, 3, 1)) + 0j, coords={"index": [0, 1, 2], "axis": [0, 1, 2], "f": [1e10]}) + >>> field = TriangularSurfaceDataset(points=points, cells=cells, values=field_values) + >>> normal = TriangularSurfaceDataset(points=points, cells=cells, values=values) + >>> monitor = SurfaceFieldMonitor( + ... size=(2,4,6), freqs=[1e10], name='field', fields=['E', 'H'] + ... ) + >>> data = SurfaceFieldData(monitor=monitor, E=[None, field], H=[None, field], normal=normal) + + """ + + monitor: SurfaceFieldMonitor = pd.Field( + ..., title="Monitor", description="Frequency-domain field monitor associated with the data." + ) + + _contains_monitor_fields = enforce_monitor_fields_present() + + def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> FieldDataset: + """Return copy of self after normalization is applied using source spectrum function.""" + fields_norm = {} + for field_name, field_data in self.field_components.items(): + fields_norm[field_name] = [None, None] + for ind in range(2): + if field_data[ind] is not None: + src_amps = source_spectrum_fn(field_data[ind].values.f) + fields_norm[field_name][ind] = field_data[ind].updated_copy( + values=(field_data[ind].values / src_amps).astype( + field_data[ind].values.dtype + ) + ) + + return self.copy(update=fields_norm) + + +class SurfaceFieldTimeData(ElectromagneticSurfaceFieldData): + """ + + Example + ------- + >>> from tidy3d import PointDataArray, IndexedFieldDataArray, TriangularSurfaceDataset, CellDataArray + >>> points = PointDataArray([[0, 0, 0], [0, 1, 0], [1, 1, 1]], dims=["index", "axis"]) + >>> cells = CellDataArray([[0, 1, 2]], dims=["cell_index", "vertex_index"]) + >>> values = PointDataArray([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dims=["index", "axis"]) + >>> field_values = IndexedFieldDataArray(np.ones((3, 3, 1)) + 0j, coords={"index": [0, 1, 2], "axis": [0, 1, 2], "f": [1e10]}) + >>> field = TriangularSurfaceDataset(points=points, cells=cells, values=field_values) + >>> normal = TriangularSurfaceDataset(points=points, cells=cells, values=values) + >>> monitor = SurfaceFieldTimeMonitor( + ... size=(2,4,6), interval=100, name='field', fields=['E', 'H'] + ... ) + >>> data = SurfaceFieldTimeData(monitor=monitor, E=[None, field], H=[None, field], normal=normal) + """ + + monitor: SurfaceFieldTimeMonitor = pd.Field( + ..., title="Monitor", description="Time-domain field monitor associated with the data." + ) + + _contains_monitor_fields = enforce_monitor_fields_present() + + @property + def poynting(self) -> ScalarFieldTimeDataArray: + """Instantaneous Poynting vector for time-domain data.""" + + poynting = [None, None] + for ind in range(2): + if self.E[ind] is not None and self.H[ind] is not None: + e_field = self.E[ind] + h_field = self.H[ind] + + poynting[ind] = e_field.updated_copy( + values=np.real(xr.cross(e_field.values.real, h_field.values.real, dim="axis")) + ) + + return poynting + + class PermittivityData(PermittivityDataset, AbstractFieldData): """Data for a :class:`.PermittivityMonitor`: diagonal components of the permittivity tensor. @@ -3696,6 +3892,8 @@ def fields_circular_polarization(self) -> xr.Dataset: FieldProjectionAngleData, DiffractionData, DirectivityData, + SurfaceFieldData, + SurfaceFieldTimeData, ) MonitorDataType = Union[MonitorDataTypes] diff --git a/tidy3d/components/data/unstructured/base.py b/tidy3d/components/data/unstructured/base.py index cd27e32707..128f83f7bd 100644 --- a/tidy3d/components/data/unstructured/base.py +++ b/tidy3d/components/data/unstructured/base.py @@ -10,7 +10,7 @@ import pydantic.v1 as pd from xarray import DataArray as XrDataArray -from tidy3d.components.base import cached_property, skip_if_fields_missing +from tidy3d.components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing from tidy3d.components.data.data_array import ( DATA_ARRAY_MAP, CellDataArray, @@ -19,7 +19,6 @@ PointDataArray, SpatialDataArray, ) -from tidy3d.components.data.dataset import Dataset from tidy3d.components.types import ArrayLike, Axis, Bound from tidy3d.constants import inf from tidy3d.exceptions import DataError, Tidy3dNotImplementedError, ValidationError @@ -31,8 +30,8 @@ DEFAULT_TOLERANCE_CELL_FINDING = 1e-6 -class UnstructuredGridDataset(Dataset, np.lib.mixins.NDArrayOperatorsMixin, ABC): - """Abstract base for datasets that store unstructured grid data.""" +class UnstructuredDataset(Tidy3dBaseModel, np.lib.mixins.NDArrayOperatorsMixin, ABC): + """Abstract base for datasets that store unstructured grid or surface data.""" points: PointDataArray = pd.Field( ..., @@ -367,20 +366,16 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # Only support operations with a scalar or an unstructured grid dataset of the same spatial dimensionality if not ( isinstance(x, numbers.Number) - or ( - isinstance(x, UnstructuredGridDataset) and x._point_dims() == self._point_dims() - ) + or (isinstance(x, type(self)) and x._point_dims() == self._point_dims()) ): raise Tidy3dNotImplementedError( f"Cannot perform arithmetic operations between instances of different classes ({type(self)} and {type(x)})." ) # Defer to the implementation of the ufunc on unwrapped values. - inputs = tuple(x.values if isinstance(x, UnstructuredGridDataset) else x for x in inputs) + inputs = tuple(x.values if isinstance(x, type(self)) else x for x in inputs) if out: - kwargs["out"] = tuple( - x.values if isinstance(x, UnstructuredGridDataset) else x for x in out - ) + kwargs["out"] = tuple(x.values if isinstance(x, type(self)) else x for x in out) result = getattr(ufunc, method)(*inputs, **kwargs) if type(result) is tuple: @@ -394,20 +389,28 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): return self.updated_copy(values=result) @property - def real(self) -> UnstructuredGridDataset: + def real(self) -> UnstructuredDataset: """Real part of dataset.""" return self.updated_copy(values=self.values.real) @property - def imag(self) -> UnstructuredGridDataset: + def imag(self) -> UnstructuredDataset: """Imaginary part of dataset.""" return self.updated_copy(values=self.values.imag) @property - def abs(self) -> UnstructuredGridDataset: + def abs(self) -> UnstructuredDataset: """Absolute value of dataset.""" return self.updated_copy(values=self.values.abs) + def conj(self) -> UnstructuredDataset: + """Complex conjugate value of dataset.""" + return self.updated_copy(values=self.values.conj()) + + def norm(self, dim) -> UnstructuredDataset: + """Compute vector norm along a given dimension.""" + return self.updated_copy(values=np.sqrt(self.values.dot(self.values.conj(), dim=dim).real)) + """ VTK interfacing """ @classmethod @@ -484,18 +487,55 @@ def _read_vtkUnstructuredGrid(fname: str): return grid @classmethod - @abstractmethod @requires_vtk def _from_vtk_obj( cls, vtk_obj, - field: str = None, + field=None, remove_degenerate_cells: bool = False, remove_unused_points: bool = False, values_type=IndexedDataArray, - expect_complex=None, - ) -> UnstructuredGridDataset: - """Initialize from a vtk object.""" + expect_complex: bool = False, + ) -> UnstructuredDataset: + """Initialize from a vtkUnstructuredGrid instance.""" + + # read point, cells, and values info from a vtk instance + cells_numpy = vtk["vtk_to_numpy"](vtk_obj.GetCells().GetConnectivityArray()) + points_numpy = vtk["vtk_to_numpy"](vtk_obj.GetPoints().GetData()) + values = cls._get_values_from_vtk( + vtk_obj, len(points_numpy), field, values_type, expect_complex + ) + + # verify cell_types + cells_types = vtk["vtk_to_numpy"](vtk_obj.GetCellTypesArray()) + if not np.all(cells_types == cls._vtk_cell_type()): + raise DataError("Only tetrahedral 'vtkUnstructuredGrid' is currently supported") + + # pack point and cell information into Tidy3D arrays + num_cells = len(cells_numpy) // cls._cell_num_vertices() + cells_numpy = np.reshape(cells_numpy, (num_cells, cls._cell_num_vertices())) + + cells = CellDataArray( + cells_numpy, + coords=dict( + cell_index=np.arange(num_cells), vertex_index=np.arange(cls._cell_num_vertices()) + ), + ) + + points = PointDataArray( + points_numpy, + coords=dict(index=np.arange(len(points_numpy)), axis=np.arange(cls._point_dims())), + ) + + if remove_degenerate_cells: + cells = cls._remove_degenerate_cells(cells=cells) + + if remove_unused_points: + points, values, cells = cls._remove_unused_points( + points=points, values=values, cells=cells + ) + + return cls(points=points, cells=cells, values=values) @requires_vtk def _from_vtk_obj_internal( @@ -503,7 +543,7 @@ def _from_vtk_obj_internal( vtk_obj, remove_degenerate_cells: bool = True, remove_unused_points: bool = True, - ) -> UnstructuredGridDataset: + ) -> UnstructuredDataset: """Initialize from a vtk object when performing internal operations. When we do that we pass structure of possibly multidimensional nature of values through parametes field and values_type. We also turn on by default cleaning of geometry.""" @@ -524,13 +564,13 @@ def from_vtu( field: str = None, remove_degenerate_cells: bool = False, remove_unused_points: bool = False, - ) -> UnstructuredGridDataset: + ) -> UnstructuredDataset: """Load unstructured data from a vtu file. Parameters ---------- fname : str - Full path to the .vtu file to load the unstructured data from. + Full path to the .vtu file to load the unstructured dataset from. field : str = None Name of the field to load. remove_degenerate_cells : bool = False @@ -540,8 +580,8 @@ def from_vtu( Returns ------- - UnstructuredGridDataset - Unstructured data. + UnstructuredDataset + Unstructured dataset. """ grid = cls._read_vtkUnstructuredGrid(file) return cls._from_vtk_obj( @@ -659,7 +699,7 @@ def get_cell_volumes(self): @requires_vtk def _plane_slice_raw(self, axis: Axis, pos: float): - """Slice data with a plane and return the resulting VTK object.""" + """Slice dataset with a plane and return the resulting VTK object.""" if pos > self.bounds[1][axis] or pos < self.bounds[0][axis]: raise DataError( @@ -697,9 +737,9 @@ def _plane_slice_raw(self, axis: Axis, pos: float): @abstractmethod @requires_vtk - def plane_slice(self, axis: Axis, pos: float) -> Union[XrDataArray, UnstructuredGridDataset]: - """Slice data with a plane and return the Tidy3D representation of the result - (``UnstructuredGridDataset``). + def plane_slice(self, axis: Axis, pos: float) -> Union[XrDataArray, UnstructuredDataset]: + """Slice dataset with a plane and return the Tidy3D representation of the result + (``UnstructuredDataset``). Parameters ---------- @@ -710,13 +750,13 @@ def plane_slice(self, axis: Axis, pos: float) -> Union[XrDataArray, Unstructured Returns ------- - Union[xarray.DataArray, UnstructuredGridDataset] + Union[xarray.DataArray, UnstructuredDataset] The resulting slice. """ @requires_vtk - def box_clip(self, bounds: Bound) -> UnstructuredGridDataset: - """Clip the unstructured grid using a box defined by ``bounds``. + def box_clip(self, bounds: Bound) -> UnstructuredDataset: + """Clip the unstructured dataset using a box defined by ``bounds``. Parameters ---------- @@ -725,8 +765,8 @@ def box_clip(self, bounds: Bound) -> UnstructuredGridDataset: Returns ------- - UnstructuredGridDataset - Clipped grid. + UnstructuredDataset + Clipped dataset. """ # make and run a VTK clipper @@ -757,10 +797,10 @@ def box_clip(self, bounds: Bound) -> UnstructuredGridDataset: @requires_vtk def reflect( self, axis: Axis, center: float, reflection_only: bool = False - ) -> UnstructuredGridDataset: - """Reflect unstructured data across the plane define by parameters ``axis`` and ``center``. - By default the original data is preserved, setting ``reflection_only`` to ``True`` will - produce only deflected data. + ) -> UnstructuredDataset: + """Reflect unstructured dataset across the plane define by parameters ``axis`` and ``center``. + By default the original dataset is preserved, setting ``reflection_only`` to ``True`` will + produce only reflected dataset. Parameters ---------- @@ -769,12 +809,12 @@ def reflect( center : float Location of the reflection plane along its normal direction. reflection_only : bool = False - Return only reflected data. + Return only reflected dataset. Returns ------- - UnstructuredGridDataset - Data after reflextion is performed. + UnstructuredDataset + Dataset after reflextion is performed. """ reflector = vtk["mod"].vtkReflectionFilter() @@ -789,6 +829,100 @@ def reflect( reflector.GetOutput(), remove_degenerate_cells=False, remove_unused_points=False ) + """ Data selection """ + + @requires_vtk + def sel( + self, + x: Union[float, ArrayLike] = None, + y: Union[float, ArrayLike] = None, + z: Union[float, ArrayLike] = None, + method: Literal["None", "nearest", "pad", "ffill", "backfill", "bfill"] = None, + **sel_kwargs, + ) -> Union[UnstructuredGridDataset, XrDataArray]: + """Extract/interpolate data along one or more spatial or non-spatial directions. Must provide at least one argument + among 'x', 'y', 'z' or non-spatial dimensions through additional arguments. Along spatial dimensions a suitable slicing of + grid is applied (plane slice, line slice, or interpolation). Selection along non-spatial dimensions is forwarded to + .sel() xarray function. Parameter 'method' applies only to non-spatial dimensions. + + Parameters + ---------- + x : Union[float, ArrayLike] = None + x-coordinate of the slice. + y : Union[float, ArrayLike] = None + y-coordinate of the slice. + z : Union[float, ArrayLike] = None + z-coordinate of the slice. + method: Literal[None, "nearest", "pad", "ffill", "backfill", "bfill"] = None + Method to use in xarray sel() function. + **sel_kwargs : dict + Keyword arguments to pass to the xarray sel() function. + + Returns + ------- + Union[TriangularGridDataset, xarray.DataArray] + Extracted data. + """ + + def _non_spatial_sel( + self, + method=None, + **sel_kwargs, + ) -> XrDataArray: + """Select/interpolate data along one or more non-Cartesian directions. + + Parameters + ---------- + **sel_kwargs : dict + Keyword arguments to pass to the xarray sel() function. + + Returns + ------- + xarray.DataArray + Extracted data. + """ + + if "index" in sel_kwargs.keys(): + raise DataError("Cannot select along dimension 'index'.") + + # convert individual values into lists of length 1 + # so that xarray doesn't drop the corresponding dimension + sel_kwargs_only_lists = { + key: value if isinstance(value, list) else [value] for key, value in sel_kwargs.items() + } + return self.updated_copy(values=self.values.sel(**sel_kwargs_only_lists, method=method)) + + def isel( + self, + **sel_kwargs, + ) -> XrDataArray: + """Select data along one or more non-Cartesian directions by coordinate index. + + Parameters + ---------- + **sel_kwargs : dict + Keyword arguments to pass to the xarray isel() function. + + Returns + ------- + xarray.DataArray + Extracted data. + """ + + if "index" in sel_kwargs.keys(): + raise DataError("Cannot select along dimension 'index'.") + + # convert individual values into lists of length 1 + # so that xarray doesn't drop the corresponding dimension + sel_kwargs_only_lists = { + key: value if isinstance(value, list) else [value] for key, value in sel_kwargs.items() + } + return self.updated_copy(values=self.values.isel(**sel_kwargs_only_lists)) + + +class UnstructuredGridDataset(UnstructuredDataset, np.lib.mixins.NDArrayOperatorsMixin, ABC): + """Abstract base for datasets that store unstructured grid data.""" + """ Interpolation """ def interp( @@ -1617,94 +1751,6 @@ def _interp_py_chunk( """ Data selection """ - @requires_vtk - def sel( - self, - x: Union[float, ArrayLike] = None, - y: Union[float, ArrayLike] = None, - z: Union[float, ArrayLike] = None, - method: Literal["None", "nearest", "pad", "ffill", "backfill", "bfill"] = None, - **sel_kwargs, - ) -> Union[UnstructuredGridDataset, XrDataArray]: - """Extract/interpolate data along one or more spatial or non-spatial directions. Must provide at least one argument - among 'x', 'y', 'z' or non-spatial dimensions through additional arguments. Along spatial dimensions a suitable slicing of - grid is applied (plane slice, line slice, or interpolation). Selection along non-spatial dimensions is forwarded to - .sel() xarray function. Parameter 'method' applies only to non-spatial dimensions. - - Parameters - ---------- - x : Union[float, ArrayLike] = None - x-coordinate of the slice. - y : Union[float, ArrayLike] = None - y-coordinate of the slice. - z : Union[float, ArrayLike] = None - z-coordinate of the slice. - method: Literal[None, "nearest", "pad", "ffill", "backfill", "bfill"] = None - Method to use in xarray sel() function. - **sel_kwargs : dict - Keyword arguments to pass to the xarray sel() function. - - Returns - ------- - Union[TriangularGridDataset, xarray.DataArray] - Extracted data. - """ - - def _non_spatial_sel( - self, - method=None, - **sel_kwargs, - ) -> XrDataArray: - """Select/interpolate data along one or more non-Cartesian directions. - - Parameters - ---------- - **sel_kwargs : dict - Keyword arguments to pass to the xarray sel() function. - - Returns - ------- - xarray.DataArray - Extracted data. - """ - - if "index" in sel_kwargs.keys(): - raise DataError("Cannot select along dimension 'index'.") - - # convert individual values into lists of length 1 - # so that xarray doesn't drop the corresponding dimension - sel_kwargs_only_lists = { - key: value if isinstance(value, list) else [value] for key, value in sel_kwargs.items() - } - return self.updated_copy(values=self.values.sel(**sel_kwargs_only_lists, method=method)) - - def isel( - self, - **sel_kwargs, - ) -> XrDataArray: - """Select data along one or more non-Cartesian directions by coordinate index. - - Parameters - ---------- - **sel_kwargs : dict - Keyword arguments to pass to the xarray isel() function. - - Returns - ------- - xarray.DataArray - Extracted data. - """ - - if "index" in sel_kwargs.keys(): - raise DataError("Cannot select along dimension 'index'.") - - # convert individual values into lists of length 1 - # so that xarray doesn't drop the corresponding dimension - sel_kwargs_only_lists = { - key: value if isinstance(value, list) else [value] for key, value in sel_kwargs.items() - } - return self.updated_copy(values=self.values.isel(**sel_kwargs_only_lists)) - @requires_vtk def sel_inside(self, bounds: Bound) -> UnstructuredGridDataset: """Return a new UnstructuredGridDataset that contains the minimal amount data necessary to diff --git a/tidy3d/components/data/unstructured/surface.py b/tidy3d/components/data/unstructured/surface.py new file mode 100644 index 0000000000..430ddd6f55 --- /dev/null +++ b/tidy3d/components/data/unstructured/surface.py @@ -0,0 +1,399 @@ +"""Defines triangular grid datasets.""" + +from __future__ import annotations + +from typing import Dict, Literal, Union + +import numpy as np +import pydantic.v1 as pd + +try: + from matplotlib import pyplot as plt +except ImportError: + pass + +from matplotlib import colormaps +from matplotlib.colors import Normalize +from xarray import DataArray as XrDataArray + +from tidy3d.components.base import cached_property +from tidy3d.components.data.data_array import ( + CellDataArray, + IndexedDataArrayTypes, + PointDataArray, +) +from tidy3d.components.types import ArrayLike, Ax, Axis +from tidy3d.components.viz import add_ax_3d_if_none, equal_aspect +from tidy3d.exceptions import DataError, Tidy3dNotImplementedError +from tidy3d.packaging import requires_vtk, vtk + +from .base import ( + UnstructuredDataset, +) + + +class TriangularSurfaceDataset(UnstructuredDataset): + """Dataset for storing triangulated surface data. Data values are associated with the nodes of + the mesh. + + Note + ---- + To use full functionality of unstructured datasets one must install ``vtk`` package (``pip + install tidy3d[vtk]`` or ``pip install vtk``). Otherwise the functionality of unstructured + datasets is limited to creation, writing to/loading from a file, and arithmetic manipulations. + + Example + ------- + >>> tri_grid_points = PointDataArray( + ... [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]], + ... coords=dict(index=np.arange(4), axis=np.arange(3)), + ... ) + >>> + >>> tri_grid_cells = CellDataArray( + ... [[0, 1, 2], [1, 2, 3]], + ... coords=dict(cell_index=np.arange(2), vertex_index=np.arange(3)), + ... ) + >>> + >>> tri_grid_values = IndexedDataArray( + ... [1.0, 2.0, 3.0, 4.0], coords=dict(index=np.arange(4)), + ... ) + >>> + >>> tri_grid = TriangularSurfaceDataset( + ... points=tri_grid_points, + ... cells=tri_grid_cells, + ... values=tri_grid_values, + ... ) + """ + + points: PointDataArray = pd.Field( + ..., + title="Surface Points", + description="Coordinates of points composing the triangulated surface.", + ) + + values: IndexedDataArrayTypes = pd.Field( + ..., + title="Surface Values", + description="Values stored at the surface points.", + ) + + cells: CellDataArray = pd.Field( + ..., + title="Surface Cells", + description="Cells composing the triangulated surface specified as connections between surface " + "points.", + ) + + """ Fundamental parameters to set up based on grid dimensionality """ + + @classmethod + def _point_dims(cls) -> pd.PositiveInt: + """Dimensionality of stored surface point coordinates.""" + return 3 + + @classmethod + def _cell_num_vertices(cls) -> pd.PositiveInt: + """Number of vertices in a cell.""" + return 3 + + """ Convenience properties """ + + @cached_property + def _points_3d_array(self) -> ArrayLike: + """3D representation of points.""" + return self.points.data + + """ VTK interfacing """ + + @classmethod + @requires_vtk + def _vtk_cell_type(cls): + """VTK cell type to use in the VTK representation.""" + return vtk["mod"].VTK_TRIANGLE + + """ Grid operations """ + + @requires_vtk + def plane_slice(self, axis: Axis, pos: float) -> XrDataArray: + """Slice data with a plane and return the resulting line as a DataArray. + + Parameters + ---------- + axis : Axis + The normal direction of the slicing plane. + pos : float + Position of the slicing plane along its normal direction. + + Returns + ------- + xarray.DataArray + The resulting slice. + """ + + raise Tidy3dNotImplementedError("Slicing of unstructured surfaces is not implemented yet.") + + """ Data selection """ + + @requires_vtk + def sel( + self, + x: Union[float, ArrayLike] = None, + y: Union[float, ArrayLike] = None, + z: Union[float, ArrayLike] = None, + method: Literal["None", "nearest", "pad", "ffill", "backfill", "bfill"] = None, + **sel_kwargs, + ) -> Union[TriangularSurfaceDataset, XrDataArray]: + """Extract/interpolate data along one or more spatial or non-spatial directions. + Currently works only for non-spatial dimensions through additional arguments. + Selection along non-spatial dimensions is forwarded to + .sel() xarray function. Parameter 'method' applies only to non-spatial dimensions. + + Parameters + ---------- + x : Union[float, ArrayLike] = None + x-coordinate of the slice. + y : Union[float, ArrayLike] = None + y-coordinate of the slice. + z : Union[float, ArrayLike] = None + z-coordinate of the slice. + method: Literal[None, "nearest", "pad", "ffill", "backfill", "bfill"] = None + Method to use in xarray sel() function. + **sel_kwargs : dict + Keyword arguments to pass to the xarray sel() function. + + Returns + ------- + Union[TriangularSurfaceDataset, xarray.DataArray] + Extracted data. + """ + + if any(comp is not None for comp in [x, y, z]): + raise Tidy3dNotImplementedError( + "Surface datasets do not support selection along x, y, or z yet." + ) + + return self._non_spatial_sel(method=method, **sel_kwargs) + + def get_cell_volumes(self): + """Get areas associated to each cell of the grid.""" + v0 = self.points[self.cells.sel(vertex_index=0)] + e01 = self.points[self.cells.sel(vertex_index=1)] - v0 + e02 = self.points[self.cells.sel(vertex_index=2)] - v0 + + return 0.5 * np.abs(np.cross(e01, e02)) + + """ Plotting """ + + @equal_aspect + @add_ax_3d_if_none + def plot( + self, + ax: Ax = None, + field: bool = True, + grid: bool = False, + cbar: bool = True, + cmap: str = "viridis", + vmin: float = None, + vmax: float = None, + buffer: float = 0.1, + cbar_kwargs: Dict = None, + ) -> Ax: + """Plot the surface mesh and/or associated data. + + Parameters + ---------- + ax : matplotlib.axes._subplots.Axes = None + matplotlib axes to plot on, if not specified, one is created. + field : bool = True + Whether to plot the data field. + grid : bool = True + Whether to plot the unstructured grid. + cbar : bool = True + Display colorbar (only if ``field == True``). + cmap : str = "viridis" + Color map to use for plotting. + vmin : float = None + The lower bound of data range that the colormap covers. If ``None``, they are + inferred from the data and other keyword arguments. + vmax : float = None + The upper bound of data range that the colormap covers. If ``None``, they are + inferred from the data and other keyword arguments. + buffer : float = 0.1 + Padding around the surface object relative to the diagonal length of the surface bounding box. + cbar_kwargs : Dict = {} + Additional parameters passed to colorbar object. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied or created matplotlib axes. + """ + + if cbar_kwargs is None: + cbar_kwargs = {} + if not (field or grid): + raise DataError("Nothing to plot ('field == False', 'grid == False').") + + # plot data field if requested + if field: + if self._num_fields != 1: + raise DataError( + "Unstructured dataset contains more than 1 field. " + "Use '.sel()' to select a single field from available dimensions " + f"{self._values_coords_dict} before plotting." + ) + + face_colors = None + face_alpha = 0 + edge_colors = None + if field: + norm = Normalize() + # np.linalg.norm(field, axis=1) + values_avg = np.mean(self.values.data.ravel()[self.cells.data], axis=1) + face_colors = colormaps[cmap](norm(values_avg)) + face_alpha = 1 + + if grid: + edge_colors = "k" + + plot_obj = ax.plot_trisurf( + self.points.data[:, 0], + self.points.data[:, 1], + self.points.data[:, 2], + triangles=self.cells.data, + fc=face_colors, + ec=edge_colors, + alpha=face_alpha, + # cmap=cmap, + vmin=vmin, + vmax=vmax, + ) + + if field and cbar: + label_kwargs = {} + if "label" not in cbar_kwargs: + label_kwargs["label"] = self.values.name + plt.colorbar(plot_obj, **cbar_kwargs, **label_kwargs) + + # set buffer + if buffer is not None: + bounds = np.array(self.bounds) + size = np.linalg.norm(bounds[1] - bounds[0]) + + ax.set_xlim(bounds[0][0] - buffer * size, bounds[1][0] + buffer * size) + ax.set_ylim(bounds[0][1] - buffer * size, bounds[1][1] + buffer * size) + ax.set_zlim(bounds[0][2] - buffer * size, bounds[1][2] + buffer * size) + + # set labels and titles + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_zlabel("z") + # ax.set_title(f"{normal_axis_name} = {self.normal_pos}") + return ax + + @equal_aspect + @add_ax_3d_if_none + def quiver( + self, + ax: Ax = None, + dim: str = "axis", + scale: float = 0.1, + downsampling: int = 1, + buffer: float = 0.1, + color: str = "magnitude", + cbar: bool = True, + cmap: str = "Spectral", + cbar_kwargs: Dict = None, + quiver_kwargs: Dict = None, + ) -> Ax: + """Plot the associated data as quiver plot. Field ``values`` must have length 3 along + the dimension representing x, y, and z components. + + Parameters + ---------- + ax : matplotlib.axes._subplots.Axes = None + matplotlib axes to plot on, if not specified, one is created. + dim : str = "axis" + Dimension along which . + scale : float = 0.1 + Size of arrows relative to the diagonal lentgh of the surface boundaing box. + downsampling : int = 1 + Step for selecting points for plotting (1 for plotting all points). + buffer : float = 0.1 + Padding around the surface object relative to the diagonal length of the surface bounding box. + cbar : bool = True + Display colorbar (only if ``field == True``). + cmap : str = "Spectral" + Color map to use for plotting. + cbar_kwargs : Dict = {} + Additional parameters passed to colorbar object. + quiver_kwargs : Dict = {} + Additional parameters passed to quiver plot function. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied or created matplotlib axes. + """ + + if cbar_kwargs is None: + cbar_kwargs = {} + if quiver_kwargs is None: + quiver_kwargs = {} + + # plot data field if requested + if self._num_fields != 3: + raise DataError( + "Unstructured dataset must contain exactly 3 fields for quiver plotting. " + "Use '.sel()' to select a single field from available dimensions " + f"{self._values_coords_dict} before plotting." + ) + + # compute max magnitude of vecotr field + mag = np.sqrt(self.values.dot(self.values.conj(), dim=dim).real) + mag_max = np.max(mag) + # compute max diagonal of dataset + size = np.subtract(self.bounds[1], self.bounds[0]) + diag = np.sqrt(np.sum(size * size)) + # scaling factor + scale_factor = scale * diag / mag_max + u = self.values.sel(**{dim: 0}).real.data[::downsampling] * scale_factor.data + v = self.values.sel(**{dim: 1}).real.data[::downsampling] * scale_factor.data + w = self.values.sel(**{dim: 2}).real.data[::downsampling] * scale_factor.data + + if color == "magnitude": + clr = plt.colormaps[cmap](1 - mag.data[::downsampling].ravel() / mag_max.data) + else: + clr = color + plot_obj = ax.quiver( + self.points.sel(axis=0).data[::downsampling], + self.points.sel(axis=1).data[::downsampling], + self.points.sel(axis=2).data[::downsampling], + u.ravel(), + v.ravel(), + w.ravel(), + color=clr, + **quiver_kwargs, + ) + + if color == "magnitude" and cbar: + label_kwargs = {} + if "label" not in cbar_kwargs: + label_kwargs["label"] = self.values.name + plt.colorbar(plot_obj, **cbar_kwargs, **label_kwargs) + + # set buffer + if buffer is not None: + bounds = np.array(self.bounds) + size = np.linalg.norm(bounds[1] - bounds[0]) + + ax.set_xlim(bounds[0][0] - buffer * size, bounds[1][0] + buffer * size) + ax.set_ylim(bounds[0][1] - buffer * size, bounds[1][1] + buffer * size) + ax.set_zlim(bounds[0][2] - buffer * size, bounds[1][2] + buffer * size) + + # set labels and titles + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_zlabel("z") + # ax.set_title(f"{normal_axis_name} = {self.normal_pos}") + return ax diff --git a/tidy3d/components/data/unstructured/tetrahedral.py b/tidy3d/components/data/unstructured/tetrahedral.py index 6c7d8ffc16..d2a0e8481c 100644 --- a/tidy3d/components/data/unstructured/tetrahedral.py +++ b/tidy3d/components/data/unstructured/tetrahedral.py @@ -9,11 +9,6 @@ from xarray import DataArray as XrDataArray from tidy3d.components.base import cached_property -from tidy3d.components.data.data_array import ( - CellDataArray, - IndexedDataArray, - PointDataArray, -) from tidy3d.components.types import ArrayLike, Axis, Bound, Coordinate from tidy3d.exceptions import DataError from tidy3d.packaging import requires_vtk, vtk @@ -87,57 +82,6 @@ def _vtk_cell_type(cls): """VTK cell type to use in the VTK representation.""" return vtk["mod"].VTK_TETRA - @classmethod - @requires_vtk - def _from_vtk_obj( - cls, - vtk_obj, - field=None, - remove_degenerate_cells: bool = False, - remove_unused_points: bool = False, - values_type=IndexedDataArray, - expect_complex: bool = False, - ) -> TetrahedralGridDataset: - """Initialize from a vtkUnstructuredGrid instance.""" - - # read point, cells, and values info from a vtk instance - cells_numpy = vtk["vtk_to_numpy"](vtk_obj.GetCells().GetConnectivityArray()) - points_numpy = vtk["vtk_to_numpy"](vtk_obj.GetPoints().GetData()) - values = cls._get_values_from_vtk( - vtk_obj, len(points_numpy), field, values_type, expect_complex - ) - - # verify cell_types - cells_types = vtk["vtk_to_numpy"](vtk_obj.GetCellTypesArray()) - if not np.all(cells_types == cls._vtk_cell_type()): - raise DataError("Only tetrahedral 'vtkUnstructuredGrid' is currently supported") - - # pack point and cell information into Tidy3D arrays - num_cells = len(cells_numpy) // cls._cell_num_vertices() - cells_numpy = np.reshape(cells_numpy, (num_cells, cls._cell_num_vertices())) - - cells = CellDataArray( - cells_numpy, - coords=dict( - cell_index=np.arange(num_cells), vertex_index=np.arange(cls._cell_num_vertices()) - ), - ) - - points = PointDataArray( - points_numpy, - coords=dict(index=np.arange(len(points_numpy)), axis=np.arange(cls._point_dims())), - ) - - if remove_degenerate_cells: - cells = cls._remove_degenerate_cells(cells=cells) - - if remove_unused_points: - points, values, cells = cls._remove_unused_points( - points=points, values=values, cells=cells - ) - - return cls(points=points, cells=cells, values=values) - """ Grid operations """ @requires_vtk diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index f72be0d861..9e2ebca0a5 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -24,6 +24,7 @@ Coordinate, Direction, EMField, + EMSurfaceField, FreqArray, FreqBound, Literal, @@ -1527,6 +1528,158 @@ def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: return BYTES_COMPLEX * num_cells * len(self.freqs) * 6 +class AbstractSurfaceMonitor(Monitor, ABC): + """:class:`Monitor` that records electromagnetic field data as a function of x,y,z on PEC surfaces.""" + + fields: Tuple[EMSurfaceField, ...] = pydantic.Field( + ["E", "H"], + title="Field Components", + description="Collection of field components to store in the monitor.", + ) + + interval_space: Tuple[Literal[1], Literal[1], Literal[1]] = pydantic.Field( + (1, 1, 1), + title="Spatial Interval", + description="Number of grid step intervals between monitor recordings. If equal to 1, " + "there will be no downsampling. If greater than 1, the step will be applied, but the " + "first and last point of the monitor grid are always included.", + ) + + colocate: Literal[False] = pydantic.Field( + False, + title="Colocate Fields", + description="For surface monitors fields are always colocated on surface.", + ) + + +class SurfaceFieldMonitor(AbstractSurfaceMonitor, FreqMonitor): + """:class:`Monitor` that records electromagnetic fields in the frequency domain on PEC surfaces. + + Notes + ----- + + :class:`SurfaceFieldMonitor` objects operate by running a discrete Fourier transform of the fields at a given set of + frequencies to perform the calculation "in-place" with the time stepping. These monitors are specifically designed + to record fields on PEC (perfect electric conductor) surfaces, storing the normal E and tangential H fields. + + Example + ------- + >>> monitor = SurfaceFieldMonitor( + ... center=(1,2,3), + ... size=(2,2,2), + ... fields=['E', 'H'], + ... freqs=[250e12, 300e12], + ... name='surface_monitor') + + """ + + def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: + """Size of monitor storage given the number of points after discretization. + In general, this is severely overestimated for surface monitors. + """ + + # estimation based on triangulated surface when it crosses cells in xy plane + num_tris = num_cells * 6 + num_points = num_cells * 4 + + # storing 3 coordinate components per point + storage = 3 * BYTES_REAL * num_points + + # storing 3 indices per triangle + storage += 3 * BYTES_REAL * num_tris + + # EH field values + normal field + storage += ( + BYTES_COMPLEX * num_points * len(self.freqs) * len(self.fields) * 3 + + 3 * num_points * BYTES_REAL + ) + + return storage + + def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: + """Size of intermediate data recorded by the monitor during a solver run.""" + + # fields + storage = BYTES_COMPLEX * num_cells * len(self.freqs) * len(self.fields) * 3 + + # fields valid map + storage += BYTES_REAL * num_cells * len(self.freqs) * len(self.fields) * 3 + + # auxiliary variables (normals and locations) + storage += BYTES_REAL * num_cells * 7 * 4 + + return storage + + +class SurfaceFieldTimeMonitor(AbstractSurfaceMonitor, TimeMonitor): + """:class:`Monitor` that records electromagnetic fields in the time domain on PEC surfaces. + + Notes + ----- + + :class:`SurfaceFieldTimeMonitor` objects are best used to monitor the time dependence of the fields + on a PEC surface. They can also be used to create “animations” of the field pattern evolution. + + To create an animation, we need to capture the frames at different time instances of the simulation. This can + be done by using a :class:`SurfaceFieldTimeMonitor`. Usually a FDTD simulation contains a large number of time steps + and grid points. Recording the field at every time step and grid point will result in a large dataset. For + the purpose of making animations, this is usually unnecessary. + + + Example + ------- + >>> monitor = SurfaceFieldTimeMonitor( + ... center=(1,2,3), + ... size=(2,2,2), + ... fields=['H'], + ... start=1e-13, + ... stop=5e-13, + ... interval=2, + ... name='movie_monitor') + + """ + + def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: + """Size of monitor storage given the number of points after discretization. + In general, this is severely overestimated for surface monitors. + """ + num_steps = self.num_steps(tmesh) + + # estimation based on triangulated surface when it crosses cells in xy plane + num_tris = num_cells * 6 + num_points = num_cells * 4 + + # storing 3 coordinate components per point + storage = 3 * BYTES_REAL * num_points + + # storing 3 indices per triangle + storage += 3 * BYTES_REAL * num_tris + + # EH field values + normal field + storage += ( + BYTES_COMPLEX * num_points * num_steps * len(self.fields) * 3 + + 3 * num_points * BYTES_REAL + ) + + return storage + + def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: + """Size of intermediate data recorded by the monitor during a solver run.""" + + num_steps = self.num_steps(tmesh) + + # fields + storage = BYTES_COMPLEX * num_cells * num_steps * len(self.fields) * 3 + + # fields valid map + storage += BYTES_REAL * num_cells * num_steps * len(self.fields) * 3 + + # auxiliary variables (normals and locations) + storage += BYTES_REAL * num_cells * 7 * 4 + + return storage + + # types of monitors that are accepted by simulation MonitorType = Union[ FieldMonitor, @@ -1542,4 +1695,6 @@ def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: FieldProjectionKSpaceMonitor, DiffractionMonitor, DirectivityMonitor, + SurfaceFieldMonitor, + SurfaceFieldTimeMonitor, ] diff --git a/tidy3d/components/types.py b/tidy3d/components/types.py index 4f4c807be4..1a2db16f57 100644 --- a/tidy3d/components/types.py +++ b/tidy3d/components/types.py @@ -225,6 +225,7 @@ def __modify_schema__(cls, field_schema): """ monitors """ EMField = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"] +EMSurfaceField = Literal["E", "H"] FieldType = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"] FreqArray = Union[Tuple[float, ...], ArrayFloat1D] ObsGridArray = Union[Tuple[float, ...], ArrayFloat1D] diff --git a/tidy3d/components/viz.py b/tidy3d/components/viz.py index ff3865a017..75041f67ee 100644 --- a/tidy3d/components/viz.py +++ b/tidy3d/components/viz.py @@ -49,6 +49,12 @@ def make_ax() -> Ax: return ax +def make_ax_3d() -> Ax: + """makes an empty ``ax`` with 3d projection.""" + _, ax = plt.subplots(1, 1, tight_layout=True, subplot_kw={"projection": "3d"}) + return ax + + def add_ax_if_none(plot): """Decorates ``plot(*args, **kwargs, ax=None)`` function. if ax=None in the function call, creates an ax and feeds it to rest of function. @@ -65,6 +71,22 @@ def _plot(*args, **kwargs) -> Ax: return _plot +def add_ax_3d_if_none(plot): + """Decorates ``plot(*args, **kwargs, ax=None)`` function. + if ax=None in the function call, creates an ax with 3d projection and feeds it to rest of function. + """ + + @wraps(plot) + def _plot(*args, **kwargs) -> Ax: + """New plot function using a generated ax if None.""" + if kwargs.get("ax") is None: + ax = make_ax_3d() + kwargs["ax"] = ax + return plot(*args, **kwargs) + + return _plot + + def equal_aspect(plot): """Decorates a plotting function returning a matplotlib axes. Ensures the aspect ratio of the returned axes is set to equal.