Skip to content

Commit eb1a1f8

Browse files
PR comments
1 parent 1c6ede7 commit eb1a1f8

File tree

11 files changed

+455
-413
lines changed

11 files changed

+455
-413
lines changed

tests/test_components/test_mode_interp.py

Lines changed: 139 additions & 195 deletions
Large diffs are not rendered by default.

tidy3d/components/data/dataset.py

Lines changed: 136 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
from __future__ import annotations
44

55
from abc import ABC, abstractmethod
6-
from typing import Any, Callable, Optional, Union, get_args
6+
from typing import Any, Callable, Literal, Optional, Union, get_args
77

88
import numpy as np
99
import pydantic.v1 as pd
1010
import xarray as xr
1111

1212
from tidy3d.components.base import Tidy3dBaseModel
13-
from tidy3d.components.types import Axis, xyz
13+
from tidy3d.components.types import Axis, FreqArray, xyz
1414
from tidy3d.constants import C_0, PICOSECOND_PER_NANOMETER_PER_KILOMETER, UnitScaling
1515
from tidy3d.exceptions import DataError
1616
from tidy3d.log import log
@@ -50,6 +50,139 @@ def data_arrs(self) -> dict:
5050
return data_arrs
5151

5252

53+
class FreqDataset(Dataset, ABC):
54+
"""Abstract base class for objects that store collections of `:class:`.DataArray`s."""
55+
56+
def _interp_in_freq_update_dict(
57+
self,
58+
freqs: FreqArray,
59+
method: Literal["linear", "cubic", "cheb"] = "linear",
60+
) -> dict[str, DataArray]:
61+
"""Interpolate mode data to new frequency points.
62+
63+
Interpolates all stored mode data (effective indices, field components, group indices,
64+
and dispersion) from the current frequency grid to a new set of frequencies. This is
65+
useful for obtaining mode data at many frequencies from computations at fewer frequencies,
66+
when modes vary smoothly with frequency.
67+
68+
Parameters
69+
----------
70+
freqs : FreqArray
71+
New frequency points to interpolate to. Should generally span a similar range
72+
as the original frequencies to avoid extrapolation.
73+
method : Literal["linear", "cubic", "cheb"]
74+
Interpolation method. ``"linear"`` for linear interpolation (requires 2+ source
75+
frequencies), ``"cubic"`` for cubic spline interpolation (requires 4+ source
76+
frequencies), ``"cheb"`` for Chebyshev polynomial interpolation using barycentric
77+
formula (requires 3+ source frequencies at Chebyshev nodes).
78+
For complex-valued data, real and imaginary parts are interpolated independently.
79+
80+
Returns
81+
-------
82+
ModeSolverData
83+
New :class:`ModeSolverData` object with data interpolated to the requested frequencies.
84+
85+
Raises
86+
------
87+
DataError
88+
If interpolation parameters are invalid (e.g., too few source frequencies for the
89+
chosen method, or source frequencies not at Chebyshev nodes for 'cheb' method).
90+
91+
Note
92+
----
93+
Interpolation assumes modes vary smoothly with frequency. Results may be inaccurate
94+
near mode crossings or regions of rapid mode variation. Use frequency tracking
95+
(``mode_spec.sort_spec.track_freq``) to help maintain mode ordering consistency.
96+
97+
For Chebyshev interpolation, source frequencies must be at Chebyshev nodes of the
98+
second kind within the frequency range.
99+
100+
Example
101+
-------
102+
>>> # Compute modes at 5 frequencies
103+
>>> import numpy as np
104+
>>> freqs_sparse = np.linspace(1e14, 2e14, 5)
105+
>>> # ... create mode_solver and compute modes ...
106+
>>> # mode_data = mode_solver.solve()
107+
>>> # Interpolate to 50 frequencies
108+
>>> freqs_dense = np.linspace(1e14, 2e14, 50)
109+
>>> # mode_data_interp = mode_data.interp(freqs=freqs_dense, method='linear')
110+
"""
111+
freqs = np.array(freqs)
112+
113+
modify_data = {}
114+
for key, data in self.data_arrs.items():
115+
modify_data[key] = self._interp_dataarray_in_freq(data, freqs, method)
116+
117+
return modify_data
118+
119+
@staticmethod
120+
def _interp_dataarray_in_freq(
121+
data: DataArray,
122+
freqs: FreqArray,
123+
method: Literal["linear", "cubic", "cheb", "nearest"],
124+
) -> DataArray:
125+
"""Interpolate a DataArray along the frequency coordinate.
126+
127+
Parameters
128+
----------
129+
data : DataArray
130+
Data array to interpolate. Must have a frequency coordinate ``"f"``.
131+
freqs : FreqArray
132+
New frequency points.
133+
method : Literal["linear", "cubic", "cheb", "nearest"]
134+
Interpolation method (``"linear"``, ``"cubic"``, ``"cheb"``, or ``"nearest"``).
135+
For ``"cheb"``, uses barycentric formula for Chebyshev interpolation.
136+
137+
Returns
138+
-------
139+
DataArray
140+
Interpolated data array with the same structure but new frequency points.
141+
"""
142+
# Map 'cheb' to xarray's 'barycentric' method
143+
xr_method = "barycentric" if method == "cheb" else method
144+
145+
# Use xarray's built-in interpolation
146+
# For complex data, this automatically interpolates real and imaginary parts
147+
interp_kwargs = {"method": xr_method}
148+
149+
if method == "nearest":
150+
return data.sel(f=freqs, method="nearest")
151+
else:
152+
if method != "cheb":
153+
interp_kwargs["kwargs"] = {"fill_value": "extrapolate"}
154+
return data.interp(f=freqs, **interp_kwargs)
155+
156+
157+
class ModeFreqDataset(FreqDataset, ABC):
158+
"""Abstract base class for objects that store collections of `:class:`.DataArray`s."""
159+
160+
def _apply_mode_reorder(self, sort_inds_2d):
161+
"""Apply a mode reordering along mode_index for all frequency indices.
162+
163+
Parameters
164+
----------
165+
sort_inds_2d : np.ndarray
166+
Array of shape (num_freqs, num_modes) where each row is the
167+
permutation to apply to the mode_index for that frequency.
168+
"""
169+
num_freqs, num_modes = sort_inds_2d.shape
170+
modify_data = {}
171+
for key, data in self.data_arrs.items():
172+
if "mode_index" not in data.dims or "f" not in data.dims:
173+
continue
174+
dims_orig = data.dims
175+
f_coord = data.coords["f"]
176+
slices = []
177+
for ifreq in range(num_freqs):
178+
sl = data.isel(f=ifreq, mode_index=sort_inds_2d[ifreq])
179+
slices.append(sl.assign_coords(mode_index=np.arange(num_modes)))
180+
# Concatenate along the 'f' dimension name and then restore original frequency coordinates
181+
data = xr.concat(slices, dim="f").assign_coords(f=f_coord).transpose(*dims_orig)
182+
modify_data[key] = data
183+
return self.updated_copy(**modify_data)
184+
185+
53186
class AbstractFieldDataset(Dataset, ABC):
54187
"""Collection of scalar fields with some symmetry properties."""
55188

@@ -492,7 +625,7 @@ class AuxFieldTimeDataset(AuxFieldDataset):
492625
)
493626

494627

495-
class ModeSolverDataset(ElectromagneticFieldDataset):
628+
class ModeSolverDataset(ElectromagneticFieldDataset, ModeFreqDataset):
496629
"""Dataset storing scalar components of E and H fields as a function of freq. and mode_index.
497630
498631
Example

tidy3d/components/data/monitor_data.py

Lines changed: 37 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -2393,6 +2393,12 @@ def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> ModeSolve
23932393
"""Return copy of self after normalization is applied using source spectrum function."""
23942394
return self.copy()
23952395

2396+
def _normalize_modes(self):
2397+
"""Normalize modes. Note: this modifies ``self`` in-place."""
2398+
scaling = np.sqrt(np.abs(self.flux))
2399+
for field in self.field_components.values():
2400+
field /= scaling
2401+
23962402
@staticmethod
23972403
def _validate_cheb_nodes(freqs: np.ndarray) -> None:
23982404
"""Validate that frequencies are approximately at Chebyshev nodes.
@@ -2415,22 +2421,21 @@ def _validate_cheb_nodes(freqs: np.ndarray) -> None:
24152421
freqs_sorted = np.sort(freqs)
24162422
expected_sorted = np.sort(expected_freqs)
24172423

2418-
# Check relative error
2419-
freq_range = np.abs(expected_freqs[-1] - expected_freqs[0])
2420-
max_error = np.max(np.abs(freqs_sorted - expected_sorted)) / freq_range
2421-
2422-
if max_error > CHEB_NODES_TOLERANCE:
2424+
# Check if frequencies are close to Chebyshev nodes
2425+
if not np.allclose(
2426+
freqs_sorted, expected_sorted, atol=CHEB_NODES_TOLERANCE, rtol=CHEB_NODES_TOLERANCE
2427+
):
24232428
raise DataError(
2424-
f"For Chebyshev interpolation ('cheb'), source frequencies must be at "
2425-
f"Chebyshev nodes of the second kind. Maximum relative error: {max_error:.2e}, "
2426-
f"tolerance: {CHEB_NODES_TOLERANCE:.2e}. Use ModeInterpSpec.sampling_points() to generate "
2427-
f"appropriate frequencies."
2429+
"For Chebyshev interpolation ('cheb'), source frequencies must be at "
2430+
"Chebyshev nodes of the second kind. Use 'ModeInterpSpec' to generate "
2431+
"appropriate frequencies."
24282432
)
24292433

2430-
def interp(
2434+
def interp_in_freq(
24312435
self,
24322436
freqs: FreqArray,
24332437
method: Literal["linear", "cubic", "cheb"] = "linear",
2438+
renormalize: Optional[bool] = False,
24342439
) -> ModeSolverData:
24352440
"""Interpolate mode data to new frequency points.
24362441
@@ -2450,6 +2455,8 @@ def interp(
24502455
frequencies), ``"cheb"`` for Chebyshev polynomial interpolation using barycentric
24512456
formula (requires 3+ source frequencies at Chebyshev nodes).
24522457
For complex-valued data, real and imaginary parts are interpolated independently.
2458+
renormalize : Optional[bool] = False
2459+
Whether to renormalize the mode profiles to unity power after interpolation.
24532460
24542461
Returns
24552462
-------
@@ -2507,36 +2514,27 @@ def interp(
25072514
f"Invalid interpolation method '{method}'. Use 'linear', 'cubic', or 'cheb'."
25082515
)
25092516

2510-
# Build update dictionary
2511-
update_dict = {}
2512-
2513-
# Interpolate n_complex (required field)
2514-
update_dict["n_complex"] = self._interp_dataarray(self.n_complex, freqs, method)
2515-
2516-
# Interpolate field components if present
2517-
for field_name, field_data in self.field_components.items():
2518-
if field_data is not None:
2519-
update_dict[field_name] = self._interp_dataarray(field_data, freqs, method)
2520-
2521-
# Interpolate n_group_raw if present
2522-
if self.n_group_raw is not None:
2523-
update_dict["n_group_raw"] = self._interp_dataarray(self.n_group_raw, freqs, method)
2517+
# Check if we're extrapolating significantly and warn
2518+
freq_min, freq_max = np.min(source_freqs), np.max(source_freqs)
2519+
new_freq_min, new_freq_max = np.min(freqs), np.max(freqs)
25242520

2525-
# Interpolate dispersion_raw if present
2526-
if self.dispersion_raw is not None:
2527-
update_dict["dispersion_raw"] = self._interp_dataarray(
2528-
self.dispersion_raw, freqs, method
2521+
if new_freq_min < freq_min * (
2522+
1 - MODE_INTERP_EXTRAPOLATION_TOLERANCE
2523+
) or new_freq_max > freq_max * (1 + MODE_INTERP_EXTRAPOLATION_TOLERANCE):
2524+
log.warning(
2525+
f"Interpolating to frequencies outside original range "
2526+
f"[{freq_min:.3e}, {freq_max:.3e}] Hz. New range: "
2527+
f"[{new_freq_min:.3e}, {new_freq_max:.3e}] Hz. "
2528+
"Results may be inaccurate due to extrapolation."
25292529
)
25302530

2531-
# Interpolate grid correction data if present
2532-
for key, data in self._grid_correction_dict.items():
2533-
if isinstance(data, DataArray) and "f" in data.coords:
2534-
update_dict[key] = self._interp_dataarray(data, freqs, method)
2531+
# Build update dictionary
2532+
update_dict = self._interp_in_freq_update_dict(freqs, method)
25352533

25362534
# Handle eps_spec if present - use nearest neighbor interpolation
25372535
if self.eps_spec is not None:
25382536
update_dict["eps_spec"] = list(
2539-
self._interp_dataarray(
2537+
self._interp_dataarray_in_freq(
25402538
FreqDataArray(self.eps_spec, coords={"f": self.monitor.freqs}),
25412539
freqs,
25422540
"nearest",
@@ -2546,57 +2544,13 @@ def interp(
25462544
# Update monitor with new frequencies
25472545
update_dict["monitor"] = self.monitor.updated_copy(freqs=list(freqs))
25482546

2549-
return self.copy(update=update_dict)
2547+
updated_data = self.updated_copy(**update_dict)
2548+
# print(updated_data.poynting)
2549+
# print(updated_data._diff_area)
2550+
if renormalize:
2551+
updated_data._normalize_modes()
25502552

2551-
@staticmethod
2552-
def _interp_dataarray(
2553-
data: DataArray,
2554-
freqs: FreqArray,
2555-
method: str,
2556-
) -> DataArray:
2557-
"""Interpolate a DataArray along the frequency coordinate.
2558-
2559-
Parameters
2560-
----------
2561-
data : DataArray
2562-
Data array to interpolate. Must have a frequency coordinate ``"f"``.
2563-
freqs : FreqArray
2564-
New frequency points.
2565-
method : str
2566-
Interpolation method (``"linear"``, ``"cubic"``, or ``"cheb"``).
2567-
For ``"cheb"``, uses barycentric formula for Chebyshev interpolation.
2568-
2569-
Returns
2570-
-------
2571-
DataArray
2572-
Interpolated data array with the same structure but new frequency points.
2573-
"""
2574-
# Map 'cheb' to xarray's 'barycentric' method
2575-
xr_method = "barycentric" if method == "cheb" else method
2576-
2577-
# Use xarray's built-in interpolation
2578-
# For complex data, this automatically interpolates real and imaginary parts
2579-
interp_kwargs = {"method": xr_method}
2580-
2581-
# Check if we're extrapolating significantly and warn
2582-
freq_min, freq_max = float(data.coords["f"].min()), float(data.coords["f"].max())
2583-
new_freq_min, new_freq_max = float(freqs.min()), float(freqs.max())
2584-
2585-
if new_freq_min < freq_min * (
2586-
1 - MODE_INTERP_EXTRAPOLATION_TOLERANCE
2587-
) or new_freq_max > freq_max * (1 + MODE_INTERP_EXTRAPOLATION_TOLERANCE):
2588-
log.warning(
2589-
f"Interpolating to frequencies outside original range "
2590-
f"[{freq_min:.3e}, {freq_max:.3e}] Hz. New range: "
2591-
f"[{new_freq_min:.3e}, {new_freq_max:.3e}] Hz. "
2592-
"Results may be inaccurate due to extrapolation."
2593-
)
2594-
interp_kwargs["kwargs"] = {"fill_value": "extrapolate"}
2595-
2596-
if method == "nearest":
2597-
return data.sel(f=freqs, method="nearest")
2598-
else:
2599-
return data.interp(f=freqs, **interp_kwargs)
2553+
return updated_data
26002554

26012555
@property
26022556
def time_reversed_copy(self) -> FieldData:

tidy3d/components/microwave/data/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
ImpedanceFreqModeDataArray,
1010
VoltageFreqModeDataArray,
1111
)
12-
from tidy3d.components.data.dataset import Dataset
12+
from tidy3d.components.data.dataset import ModeFreqDataset
1313

1414

15-
class TransmissionLineDataset(Dataset):
15+
class TransmissionLineDataset(ModeFreqDataset):
1616
"""Holds mode data that is specific to transmission lines in microwave and RF applications,
1717
like characteristic impedance.
1818

0 commit comments

Comments
 (0)