Skip to content

Commit dc4be8a

Browse files
recalculation of grid correction factors
1 parent de5b620 commit dc4be8a

File tree

4 files changed

+181
-40
lines changed

4 files changed

+181
-40
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3333
- Added a Gaussian inverse design filter option with autograd gradients and complete padding mode coverage.
3434
- Added support for argument passing to DRC file when running checks with `DRCRunner.run(..., drc_args={key: value})` in klayout plugin.
3535
- Added support for `nonlinear_spec` in `CustomMedium` and `CustomDispersiveMedium`.
36+
- Added `interp_spec` in `ModeSpec` to allow downsampling and interpolation of waveguide modes in frequency.
3637

3738
### Breaking Changes
3839
- Edge singularity correction at PEC and lossy metal edges defaults to `True`.

docs/api/mode.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@ Mode Specifications
77
:toctree: _autosummary/
88
:template: module.rst
99

10-
tidy3d.ModeSpec
10+
tidy3d.ModeSpec
11+
tidy3d.ModeSortSpec
12+
tidy3d.ModeInterpSpec

tidy3d/components/data/monitor_data.py

Lines changed: 121 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from tidy3d.components.base_sim.data.monitor_data import AbstractMonitorData
1919
from tidy3d.components.grid.grid import Coords, Grid
2020
from tidy3d.components.medium import Medium, MediumType
21-
from tidy3d.components.mode_spec import ModeInterpSpec, ModeSortSpec
21+
from tidy3d.components.mode_spec import ModeInterpSpec, ModeSortSpec, ModeSpec
2222
from tidy3d.components.monitor import (
2323
AuxFieldTimeMonitor,
2424
DiffractionMonitor,
@@ -45,6 +45,7 @@
4545
ArrayFloat1D,
4646
ArrayFloat2D,
4747
Coordinate,
48+
Direction,
4849
EMField,
4950
EpsSpecType,
5051
FreqArray,
@@ -79,6 +80,7 @@
7980
MixedModeDataArray,
8081
ModeAmpsDataArray,
8182
ModeDispersionDataArray,
83+
ModeIndexDataArray,
8284
ScalarFieldDataArray,
8385
ScalarFieldTimeDataArray,
8486
TimeDataArray,
@@ -2458,6 +2460,36 @@ class ModeSolverData(ModeData):
24582460
None, title="Amplitudes", description="Unused for ModeSolverData."
24592461
)
24602462

2463+
grid_distances_primal: Union[tuple[float], tuple[float, float]] = pd.Field(
2464+
(0.0,),
2465+
title="Distances to the Primal Grid",
2466+
description="Relative distances to the primal grid locations along the normal direction in "
2467+
"the original simulation grid. Needed to recalculate grid corrections after "
2468+
"interpolating in frequency.",
2469+
)
2470+
2471+
grid_distances_dual: Union[tuple[float], tuple[float, float]] = pd.Field(
2472+
(0.0,),
2473+
title="Distances to the Dual Grid",
2474+
description="Relative distances to the dual grid locations along the normal direction in "
2475+
"the original simulation grid. Needed to recalculate grid corrections after "
2476+
"interpolating in frequency.",
2477+
)
2478+
2479+
@pd.validator("eps_spec", always=True)
2480+
@skip_if_fields_missing(["monitor"])
2481+
def eps_spec_match_mode_spec(cls, val, values):
2482+
"""Raise validation error if frequencies in eps_spec does not match frequency list"""
2483+
if val:
2484+
mnt = values["monitor"]
2485+
if (mnt.reduce_data and len(val) != mnt.mode_spec.interp_spec.num_points) or (
2486+
not mnt.reduce_data and len(val) != len(mnt.freqs)
2487+
):
2488+
raise ValidationError(
2489+
"eps_spec must be provided at the same frequencies as mode solver data."
2490+
)
2491+
return val
2492+
24612493
def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> ModeSolverData:
24622494
"""Return copy of self after normalization is applied using source spectrum function."""
24632495
return self.copy()
@@ -2468,6 +2500,63 @@ def _normalize_modes(self):
24682500
for field in self.field_components.values():
24692501
field /= scaling
24702502

2503+
@staticmethod
2504+
def _grid_correction_factors(
2505+
primal_distances: tuple[float, ...],
2506+
dual_distances: tuple[float, ...],
2507+
mode_spec: ModeSpec,
2508+
n_complex: ModeIndexDataArray,
2509+
direction: Direction,
2510+
normal_dim: str,
2511+
) -> tuple[FreqModeDataArray, FreqModeDataArray]:
2512+
"""Calculate the grid correction factors for the primal and dual grid.
2513+
2514+
Parameters
2515+
----------
2516+
primal_distances : tuple[float, ...]
2517+
Relative distances to the primal grid locations along the normal direction in the original simulation grid.
2518+
dual_distances : tuple[float, ...]
2519+
Relative distances to the dual grid locations along the normal direction in the original simulation grid.
2520+
mode_spec : ModeSpec
2521+
Mode specification.
2522+
n_complex : ModeIndexDataArray
2523+
Effective indices of the modes.
2524+
direction : Direction
2525+
Direction of the propagation.
2526+
normal_dim : str
2527+
Name of the normal dimension.
2528+
2529+
Returns
2530+
-------
2531+
tuple[FreqModeDataArray, FreqModeDataArray]
2532+
Grid correction factors for the primal and dual grid.
2533+
"""
2534+
2535+
distances_primal = xr.DataArray(primal_distances, coords={normal_dim: primal_distances})
2536+
distances_dual = xr.DataArray(dual_distances, coords={normal_dim: dual_distances})
2537+
2538+
# Propagation phase at the primal and dual locations. The k-vector is along the propagation
2539+
# direction, so angle_theta has to be taken into account. The distance along the propagation
2540+
# direction is the distance along the normal direction over cosine(theta).
2541+
cos_theta = np.cos(mode_spec.angle_theta)
2542+
k_vec = cos_theta * 2 * np.pi * n_complex * n_complex.f / C_0
2543+
if direction == "-":
2544+
k_vec *= -1
2545+
phase_primal = np.exp(1j * k_vec * distances_primal)
2546+
phase_dual = np.exp(1j * k_vec * distances_dual)
2547+
2548+
# Fields are modified by a linear interpolation to the exact monitor position
2549+
if distances_primal.size > 1:
2550+
phase_primal = phase_primal.interp(**{normal_dim: 0})
2551+
else:
2552+
phase_primal = phase_primal.squeeze(dim=normal_dim)
2553+
if distances_dual.size > 1:
2554+
phase_dual = phase_dual.interp(**{normal_dim: 0})
2555+
else:
2556+
phase_dual = phase_dual.squeeze(dim=normal_dim)
2557+
2558+
return FreqModeDataArray(phase_primal), FreqModeDataArray(phase_dual)
2559+
24712560
@staticmethod
24722561
def _validate_cheb_nodes(freqs: np.ndarray) -> None:
24732562
"""Validate that frequencies are approximately at Chebyshev nodes.
@@ -2504,7 +2593,8 @@ def interp_in_freq(
25042593
self,
25052594
freqs: FreqArray,
25062595
method: Literal["linear", "cubic", "cheb"] = "linear",
2507-
renormalize: Optional[bool] = True,
2596+
renormalize: bool = True,
2597+
recalculate_grid_correction: bool = True,
25082598
) -> ModeSolverData:
25092599
"""Interpolate mode data to new frequency points.
25102600
@@ -2524,10 +2614,10 @@ def interp_in_freq(
25242614
frequencies), ``"cheb"`` for Chebyshev polynomial interpolation using barycentric
25252615
formula (requires 3+ source frequencies at Chebyshev nodes).
25262616
For complex-valued data, real and imaginary parts are interpolated independently.
2527-
renormalize : Optional[bool] = True
2617+
renormalize : bool = True
25282618
Whether to renormalize the mode profiles to unity power after interpolation.
2529-
recalculate_grid_correction : Optional[bool] = True
2530-
Whether to recalculate the grid correction after interpolation or use interpolated
2619+
recalculate_grid_correction : bool = True
2620+
Whether to recalculate the grid correction factors after interpolation or use interpolated
25312621
grid corrections.
25322622
25332623
Returns
@@ -2563,7 +2653,11 @@ def interp_in_freq(
25632653
"""
25642654
# Validate input
25652655
freqs = np.array(freqs)
2656+
25662657
source_freqs = np.array(self.monitor.freqs)
2658+
if self.monitor.reduce_data:
2659+
# it is validated that if reduce_data is True, then interp_spec is not None
2660+
source_freqs = self.monitor.mode_spec.interp_spec.sampling_points(source_freqs)
25672661

25682662
# Validate method-specific requirements
25692663
if method == "cubic" and len(source_freqs) < 4:
@@ -2607,14 +2701,30 @@ def interp_in_freq(
26072701
if self.eps_spec is not None:
26082702
update_dict["eps_spec"] = list(
26092703
self._interp_dataarray_in_freq(
2610-
FreqDataArray(self.eps_spec, coords={"f": self.monitor.freqs}),
2704+
FreqDataArray(self.eps_spec, coords={"f": source_freqs}),
26112705
freqs,
26122706
"nearest",
26132707
).data
26142708
)
26152709

2616-
# Update monitor with new frequencies
2617-
update_dict["monitor"] = self.monitor.updated_copy(freqs=list(freqs))
2710+
# Update monitor with new frequencies, remove interp_spec and set reduce_data to False
2711+
update_dict["monitor"] = self.monitor.updated_copy(
2712+
freqs=list(freqs),
2713+
mode_spec=self.monitor.mode_spec.updated_copy(interp_spec=None),
2714+
reduce_data=False,
2715+
)
2716+
2717+
if recalculate_grid_correction:
2718+
update_dict["grid_primal_correction"], update_dict["grid_dual_correction"] = (
2719+
self._grid_correction_factors(
2720+
list(self.grid_distances_primal),
2721+
list(self.grid_distances_dual),
2722+
self.monitor.mode_spec,
2723+
update_dict["n_complex"],
2724+
self.monitor.direction,
2725+
"xyz"[self.monitor._normal_axis],
2726+
)
2727+
)
26182728

26192729
updated_data = self.updated_copy(**update_dict)
26202730
if renormalize:
@@ -2627,12 +2737,13 @@ def interpolated_copy(self) -> ModeSolverData:
26272737
"""Return a copy of the data with interpolated fields."""
26282738
if self.monitor.mode_spec.interp_spec is None or not self.monitor.reduce_data:
26292739
return self
2630-
return self.interp_in_freq(
2740+
interpolated_data = self.interp_in_freq(
26312741
freqs=self.monitor.freqs,
26322742
method=self.monitor.mode_spec.interp_spec.method,
26332743
renormalize=True,
2634-
monitor=self.monitor.updated_copy(reduce_data=False),
2744+
recalculate_grid_correction=True,
26352745
)
2746+
return interpolated_data.updated_copy(monitor=self.monitor.updated_copy(reduce_data=False))
26362747

26372748
@property
26382749
def time_reversed_copy(self) -> FieldData:

tidy3d/components/mode/mode_solver.py

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -537,14 +537,17 @@ def _get_data_with_interp(self) -> ModeSolverData:
537537
# Get data at reduced frequencies
538538
data_reduced = mode_solver_reduced.data_raw
539539

540+
# restore original mode_spec
541+
data_reduced = data_reduced.updated_copy(
542+
monitor=data_reduced.monitor.updated_copy(
543+
freqs=self.freqs, mode_spec=self.mode_spec, reduce_data=True
544+
)
545+
)
546+
540547
if self.reduce_data:
541548
return data_reduced
542549

543-
return data_reduced.interp_in_freq(
544-
freqs=self.freqs,
545-
method=self.mode_spec.interp_spec.method,
546-
renormalize=True,
547-
)
550+
return data_reduced.interpolated_copy
548551

549552
@cached_property
550553
def grid_snapped(self) -> Grid:
@@ -670,7 +673,7 @@ def rotated_mode_solver_data(self) -> ModeSolverData:
670673
for _ in self.freqs:
671674
eps_spec.append("tensorial_complex")
672675
# finite grid corrections
673-
grid_factors = solver._grid_correction(
676+
grid_factors, relative_grid_distances = solver._grid_correction(
674677
simulation=solver.simulation,
675678
plane=solver.plane,
676679
mode_spec=solver.mode_spec,
@@ -689,6 +692,8 @@ def rotated_mode_solver_data(self) -> ModeSolverData:
689692
grid_primal_correction=grid_factors[0],
690693
grid_dual_correction=grid_factors[1],
691694
eps_spec=eps_spec,
695+
grid_distances_primal=relative_grid_distances[0],
696+
grid_distances_dual=relative_grid_distances[1],
692697
**rotated_mode_fields,
693698
)
694699

@@ -1257,7 +1262,7 @@ def _data_on_yee_grid(self) -> ModeSolverData:
12571262
data_dict[field_name] = scalar_field_data
12581263

12591264
# finite grid corrections
1260-
grid_factors = solver._grid_correction(
1265+
grid_factors, relative_grid_distances = solver._grid_correction(
12611266
simulation=solver.simulation,
12621267
plane=solver.plane,
12631268
mode_spec=solver.mode_spec,
@@ -1275,6 +1280,8 @@ def _data_on_yee_grid(self) -> ModeSolverData:
12751280
grid_expanded=grid_expanded,
12761281
grid_primal_correction=grid_factors[0],
12771282
grid_dual_correction=grid_factors[1],
1283+
grid_distances_primal=relative_grid_distances[0],
1284+
grid_distances_dual=relative_grid_distances[1],
12781285
eps_spec=eps_spec,
12791286
**data_dict,
12801287
)
@@ -1329,7 +1336,7 @@ def _data_on_yee_grid_relative(self, basis: ModeSolverData) -> ModeSolverData:
13291336
data_dict[field_name] = scalar_field_data
13301337

13311338
# finite grid corrections
1332-
grid_factors = self._grid_correction(
1339+
grid_factors, relative_grid_distances = self._grid_correction(
13331340
simulation=self.simulation,
13341341
plane=self.plane,
13351342
mode_spec=self.mode_spec,
@@ -1347,6 +1354,8 @@ def _data_on_yee_grid_relative(self, basis: ModeSolverData) -> ModeSolverData:
13471354
grid_expanded=grid_expanded,
13481355
grid_primal_correction=grid_factors[0],
13491356
grid_dual_correction=grid_factors[1],
1357+
grid_distances_primal=relative_grid_distances[0],
1358+
grid_distances_dual=relative_grid_distances[1],
13501359
eps_spec=eps_spec,
13511360
**data_dict,
13521361
)
@@ -1892,7 +1901,9 @@ def _grid_correction(
18921901
mode_spec: ModeSpec,
18931902
n_complex: ModeIndexDataArray,
18941903
direction: Direction,
1895-
) -> [FreqModeDataArray, FreqModeDataArray]:
1904+
) -> tuple[
1905+
tuple[FreqModeDataArray, FreqModeDataArray], tuple[tuple[float, ...], tuple[float, ...]]
1906+
]:
18961907
"""
18971908
Compute grid correction factors for the mode fields.
18981909
@@ -1935,27 +1946,43 @@ def _grid_correction(
19351946
normal_dual = grid.centers.to_list[normal_axis]
19361947
normal_dual = xr.DataArray(normal_dual, coords={normal_dim: normal_dual})
19371948

1938-
# Propagation phase at the primal and dual locations. The k-vector is along the propagation
1939-
# direction, so angle_theta has to be taken into account. The distance along the propagation
1940-
# direction is the distance along the normal direction over cosine(theta).
1941-
cos_theta = np.cos(mode_spec.angle_theta)
1942-
k_vec = cos_theta * 2 * np.pi * n_complex * n_complex.f / C_0
1943-
if direction == "-":
1944-
k_vec *= -1
1945-
phase_primal = np.exp(1j * k_vec * (normal_primal - normal_pos))
1946-
phase_dual = np.exp(1j * k_vec * (normal_dual - normal_pos))
1947-
1948-
# Fields are modified by a linear interpolation to the exact monitor position
1949-
if normal_primal.size > 1:
1950-
phase_primal = phase_primal.interp(**{normal_dim: normal_pos})
1951-
else:
1952-
phase_primal = phase_primal.squeeze(dim=normal_dim)
1953-
if normal_dual.size > 1:
1954-
phase_dual = phase_dual.interp(**{normal_dim: normal_pos})
1955-
else:
1956-
phase_dual = phase_dual.squeeze(dim=normal_dim)
1949+
def find_closest_distances_to_grid_points(
1950+
normal_pos: float, grid_coords: ArrayFloat1D
1951+
) -> tuple[float, float]:
1952+
"""Find the closest points to the normal position in the grid coordinates."""
1953+
1954+
if grid_coords.size == 1:
1955+
return [float(grid_coords.data[0] - normal_pos)]
1956+
1957+
distances = grid_coords.data - normal_pos
1958+
# First, find the signed distance to the closest grid point
1959+
closest_distance_ind = np.argmin(np.abs(distances))
1960+
closest_distance = distances[closest_distance_ind]
1961+
1962+
# Then, if the closest distance is positive, take the previous point, otherwise take the next point
1963+
if closest_distance > 0:
1964+
first_dist = distances[closest_distance_ind - 1]
1965+
second_dist = distances[closest_distance_ind]
1966+
else:
1967+
first_dist = distances[closest_distance_ind]
1968+
second_dist = distances[closest_distance_ind + 1]
1969+
1970+
# Return the two closest points
1971+
return [first_dist, second_dist]
1972+
1973+
primal_closest_distances = find_closest_distances_to_grid_points(normal_pos, normal_primal)
1974+
dual_closest_distances = find_closest_distances_to_grid_points(normal_pos, normal_dual)
1975+
1976+
grid_correction_factors = ModeSolverData._grid_correction_factors(
1977+
primal_closest_distances,
1978+
dual_closest_distances,
1979+
mode_spec,
1980+
n_complex,
1981+
direction,
1982+
normal_dim,
1983+
)
19571984

1958-
return FreqModeDataArray(phase_primal), FreqModeDataArray(phase_dual)
1985+
return grid_correction_factors, (primal_closest_distances, dual_closest_distances)
19591986

19601987
@property
19611988
def _is_tensorial(self) -> bool:

0 commit comments

Comments
 (0)