99
1010try :
1111 from matplotlib import pyplot as plt
12- from matplotlib .tri import Triangulation
1312except ImportError :
1413 pass
1514
15+ from matplotlib import colormaps
16+ from matplotlib .colors import Normalize
1617from xarray import DataArray as XrDataArray
1718
1819from tidy3d .components .base import cached_property
1920from tidy3d .components .data .data_array import (
2021 CellDataArray ,
21- IndexedDataArray ,
22+ IndexedDataArrayTypes ,
2223 PointDataArray ,
23- SpatialDataArray ,
24- IndexedDataArrayTypes
2524)
26- from tidy3d .components .types import ArrayLike , Ax , Axis , Bound
27- from tidy3d .components .viz import add_ax_3d_if_none , equal_aspect , plot_params_grid
28- from tidy3d .constants import inf
25+ from tidy3d .components .types import ArrayLike , Ax , Axis
26+ from tidy3d .components .viz import add_ax_3d_if_none , equal_aspect
2927from tidy3d .exceptions import DataError , Tidy3dNotImplementedError
30- from tidy3d .log import log
3128from tidy3d .packaging import requires_vtk , vtk
3229
3330from .base import (
3431 UnstructuredDataset ,
3532)
36- from matplotlib .colors import Normalize
37- from matplotlib import colormaps
3833
3934
4035class TriangularSurfaceDataset (UnstructuredDataset ):
@@ -148,8 +143,8 @@ def sel(
148143 method : Literal ["None" , "nearest" , "pad" , "ffill" , "backfill" , "bfill" ] = None ,
149144 ** sel_kwargs ,
150145 ) -> Union [TriangularSurfaceDataset , XrDataArray ]:
151- """Extract/interpolate data along one or more spatial or non-spatial directions.
152- Currently works only for non-spatial dimensions through additional arguments.
146+ """Extract/interpolate data along one or more spatial or non-spatial directions.
147+ Currently works only for non-spatial dimensions through additional arguments.
153148 Selection along non-spatial dimensions is forwarded to
154149 .sel() xarray function. Parameter 'method' applies only to non-spatial dimensions.
155150
@@ -173,18 +168,19 @@ def sel(
173168 """
174169
175170 if any (comp is not None for comp in [x , y , z ]):
176- raise Tidy3dNotImplementedError ("Surface datasets do not support selection along x, y, or z yet." )
177-
171+ raise Tidy3dNotImplementedError (
172+ "Surface datasets do not support selection along x, y, or z yet."
173+ )
174+
178175 return self ._non_spatial_sel (method = method , ** sel_kwargs )
179-
176+
180177 def get_cell_volumes (self ):
181178 """Get areas associated to each cell of the grid."""
182179 v0 = self .points [self .cells .sel (vertex_index = 0 )]
183180 e01 = self .points [self .cells .sel (vertex_index = 1 )] - v0
184181 e02 = self .points [self .cells .sel (vertex_index = 2 )] - v0
185182
186183 return 0.5 * np .abs (np .cross (e01 , e02 ))
187-
188184
189185 """ Plotting """
190186
@@ -246,7 +242,7 @@ def plot(
246242 "Use '.sel()' to select a single field from available dimensions "
247243 f"{ self ._values_coords_dict } before plotting."
248244 )
249-
245+
250246 face_colors = None
251247 face_alpha = 0
252248 edge_colors = None
@@ -256,16 +252,16 @@ def plot(
256252 values_avg = np .mean (self .values .data .ravel ()[self .cells .data ], axis = 1 )
257253 face_colors = colormaps [cmap ](norm (values_avg ))
258254 face_alpha = 1
259-
255+
260256 if grid :
261257 edge_colors = "k"
262258
263259 plot_obj = ax .plot_trisurf (
264- self .points .data [:, 0 ],
265- self .points .data [:, 1 ],
266- self .points .data [:, 2 ],
267- triangles = self .cells .data ,
268- fc = face_colors ,
260+ self .points .data [:, 0 ],
261+ self .points .data [:, 1 ],
262+ self .points .data [:, 2 ],
263+ triangles = self .cells .data ,
264+ fc = face_colors ,
269265 ec = edge_colors ,
270266 alpha = face_alpha ,
271267 # cmap=cmap,
@@ -283,7 +279,7 @@ def plot(
283279 if buffer is not None :
284280 bounds = np .array (self .bounds )
285281 size = np .linalg .norm (bounds [1 ] - bounds [0 ])
286-
282+
287283 ax .set_xlim (bounds [0 ][0 ] - buffer * size , bounds [1 ][0 ] + buffer * size )
288284 ax .set_ylim (bounds [0 ][1 ] - buffer * size , bounds [1 ][1 ] + buffer * size )
289285 ax .set_zlim (bounds [0 ][2 ] - buffer * size , bounds [1 ][2 ] + buffer * size )
@@ -310,7 +306,7 @@ def quiver(
310306 cbar_kwargs : Dict = None ,
311307 quiver_kwargs : Dict = None ,
312308 ) -> Ax :
313- """Plot the associated data as quiver plot. Field ``values`` must have length 3 along
309+ """Plot the associated data as quiver plot. Field ``values`` must have length 3 along
314310 the dimension representing x, y, and z components.
315311
316312 Parameters
@@ -352,7 +348,7 @@ def quiver(
352348 "Use '.sel()' to select a single field from available dimensions "
353349 f"{ self ._values_coords_dict } before plotting."
354350 )
355-
351+
356352 # compute max magnitude of vecotr field
357353 mag = np .sqrt (self .values .dot (self .values .conj (), dim = dim ).real )
358354 mag_max = np .max (mag )
@@ -364,19 +360,19 @@ def quiver(
364360 u = self .values .sel (** {dim : 0 }).real .data [::downsampling ] * scale_factor .data
365361 v = self .values .sel (** {dim : 1 }).real .data [::downsampling ] * scale_factor .data
366362 w = self .values .sel (** {dim : 2 }).real .data [::downsampling ] * scale_factor .data
367-
363+
368364 if color == "magnitude" :
369- clr = plt .colormaps [cmap ](1 - mag .data [::downsampling ].ravel ()/ mag_max .data )
365+ clr = plt .colormaps [cmap ](1 - mag .data [::downsampling ].ravel () / mag_max .data )
370366 else :
371367 clr = color
372368 plot_obj = ax .quiver (
373- self .points .sel (axis = 0 ).data [::downsampling ],
374- self .points .sel (axis = 1 ).data [::downsampling ],
369+ self .points .sel (axis = 0 ).data [::downsampling ],
370+ self .points .sel (axis = 1 ).data [::downsampling ],
375371 self .points .sel (axis = 2 ).data [::downsampling ],
376- u .ravel (),
377- v .ravel (),
378- w .ravel (),
379- color = clr ,
372+ u .ravel (),
373+ v .ravel (),
374+ w .ravel (),
375+ color = clr ,
380376 ** quiver_kwargs ,
381377 )
382378
@@ -390,7 +386,7 @@ def quiver(
390386 if buffer is not None :
391387 bounds = np .array (self .bounds )
392388 size = np .linalg .norm (bounds [1 ] - bounds [0 ])
393-
389+
394390 ax .set_xlim (bounds [0 ][0 ] - buffer * size , bounds [1 ][0 ] + buffer * size )
395391 ax .set_ylim (bounds [0 ][1 ] - buffer * size , bounds [1 ][1 ] + buffer * size )
396392 ax .set_zlim (bounds [0 ][2 ] - buffer * size , bounds [1 ][2 ] + buffer * size )
0 commit comments