Skip to content

Commit 9bdb814

Browse files
authored
Merge pull request #253 from NeuroML/feat/hybrid-plots
Feat/hybrid plots: allow users to specify how they want various cells in networks to be plotted
2 parents 0899741 + 03f9475 commit 9bdb814

File tree

6 files changed

+352
-68
lines changed

6 files changed

+352
-68
lines changed

pyneuroml/plot/PlotMorphology.py

Lines changed: 107 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,17 @@ def process_args():
7575
help="Plane to plot on for 2D plot",
7676
)
7777

78+
parser.add_argument(
79+
"-pointFraction",
80+
type=str,
81+
metavar="<fraction of each population to plot as point cells>",
82+
default=DEFAULTS["pointFraction"],
83+
help="Fraction of network to plot as point cells",
84+
)
7885
parser.add_argument(
7986
"-plotType",
8087
type=str,
81-
metavar="<type: detailed, constant, or schematic>",
88+
metavar="<type: detailed, constant, schematic, or point>",
8289
default=DEFAULTS["plotType"],
8390
help="Level of detail to plot in",
8491
)
@@ -147,6 +154,7 @@ def plot_from_console(a: typing.Optional[typing.Any] = None, **kwargs: str):
147154
verbose=a.v,
148155
plot_type=a.plot_type,
149156
theme=a.theme,
157+
plot_spec={"point_fraction": a.point_fraction},
150158
)
151159
else:
152160
plot_2D(
@@ -158,6 +166,7 @@ def plot_from_console(a: typing.Optional[typing.Any] = None, **kwargs: str):
158166
a.save_to_file,
159167
a.square,
160168
a.plot_type,
169+
plot_spec={"point_fraction": a.point_fraction},
161170
)
162171

163172

@@ -172,6 +181,9 @@ def plot_2D(
172181
plot_type: str = "detailed",
173182
title: typing.Optional[str] = None,
174183
close_plot: bool = False,
184+
plot_spec: typing.Optional[
185+
typing.Dict[str, typing.Union[str, typing.List[int], float]]
186+
] = None,
175187
):
176188
"""Plot cells in a 2D plane.
177189
@@ -205,6 +217,7 @@ def plot_2D(
205217
- "constant": show morphology, but use constant line widths
206218
- "schematic": only plot each unbranched segment group as a straight
207219
line, not following each segment
220+
- "point": show all cells as points
208221
209222
This is only applicable for neuroml.Cell cells (ones with some
210223
morphology)
@@ -214,20 +227,36 @@ def plot_2D(
214227
:type title: str
215228
:param close_plot: call pyplot.close() to close plot after plotting
216229
:type close_plot: bool
230+
:param plot_spec: dictionary that allows passing some specifications that
231+
control how a plot is generated. This is mostly useful for large
232+
network plots where one may want to have a mix of full morphology and
233+
schematic, and point representations of cells. Possible keys are:
234+
235+
- point_fraction: what fraction of each population to plot as point cells:
236+
these cells will be randomly selected
237+
- points_cells: list of cell ids to plot as point cells
238+
- schematic_cells: list of cell ids to plot as schematics
239+
- constant_cells: list of cell ids to plot as constant widths
240+
241+
The last three lists override the point_fraction setting. If a cell id
242+
is not included in the spec here, it will follow the plot_type provided
243+
before.
217244
"""
218245

219-
if plot_type not in ["detailed", "constant", "schematic"]:
246+
if plot_type not in ["detailed", "constant", "schematic", "point"]:
220247
raise ValueError(
221-
"plot_type must be one of 'detailed', 'constant', or 'schematic'"
248+
"plot_type must be one of 'detailed', 'constant', 'schematic', 'point'"
222249
)
223250

224251
if verbose:
225252
print("Plotting %s" % nml_file)
226253

227-
if type(nml_file) == str:
254+
# do not recursive read the file, the extract_position_info function will
255+
# do that for us, from a copy of the model
256+
if type(nml_file) is str:
228257
nml_model = read_neuroml2_file(
229258
nml_file,
230-
include_includes=True,
259+
include_includes=False,
231260
check_validity_pre_include=False,
232261
verbose=False,
233262
optimized=True,
@@ -250,7 +279,9 @@ def plot_2D(
250279
positions,
251280
pop_id_vs_color,
252281
pop_id_vs_radii,
253-
) = extract_position_info(nml_model, verbose)
282+
) = extract_position_info(
283+
nml_model, verbose, nml_file if type(nml_file) is str else ""
284+
)
254285

255286
if title is None:
256287
if len(nml_model.networks) > 0:
@@ -268,12 +299,45 @@ def plot_2D(
268299
fig, ax = get_new_matplotlib_morph_plot(title, plane2d)
269300
axis_min_max = [float("inf"), -1 * float("inf")]
270301

271-
for pop_id in pop_id_vs_cell:
272-
cell = pop_id_vs_cell[pop_id]
273-
pos_pop = positions[pop_id]
302+
# process plot_spec
303+
point_cells = [] # type: typing.List[int]
304+
schematic_cells = [] # type: typing.List[int]
305+
constant_cells = [] # type: typing.List[int]
306+
detailed_cells = [] # type: typing.List[int]
307+
if plot_spec is not None:
308+
try:
309+
point_cells = plot_spec["point_cells"]
310+
except KeyError:
311+
pass
312+
try:
313+
schematic_cells = plot_spec["schematic_cells"]
314+
except KeyError:
315+
pass
316+
try:
317+
constant_cells = plot_spec["constant_cells"]
318+
except KeyError:
319+
pass
320+
try:
321+
detailed_cells = plot_spec["detailed_cells"]
322+
except KeyError:
323+
pass
324+
325+
for pop_id, cell in pop_id_vs_cell.items():
326+
pos_pop = positions[pop_id] # type: typing.Dict[typing.Any, typing.List[float]]
327+
328+
# reinit point_cells for each loop
329+
point_cells_pop = []
330+
if len(point_cells) == 0 and plot_spec is not None:
331+
cell_indices = list(pos_pop.keys())
332+
try:
333+
point_cells_pop = random.sample(
334+
cell_indices,
335+
int(len(cell_indices) * float(plot_spec["point_fraction"])),
336+
)
337+
except KeyError:
338+
pass
274339

275-
for cell_index in pos_pop:
276-
pos = pos_pop[cell_index]
340+
for cell_index, pos in pos_pop.items():
277341
radius = pop_id_vs_radii[pop_id] if pop_id in pop_id_vs_radii else 10
278342
color = pop_id_vs_color[pop_id] if pop_id in pop_id_vs_color else None
279343

@@ -291,12 +355,36 @@ def plot_2D(
291355
nogui=True,
292356
)
293357
else:
294-
if plot_type == "schematic":
358+
if (
359+
plot_type == "point"
360+
or cell_index in point_cells_pop
361+
or cell.id in point_cells
362+
):
363+
# assume that soma is 0, plot point at where soma should be
364+
soma_x_y_z = cell.get_actual_proximal(0)
365+
pos1 = [
366+
pos[0] + soma_x_y_z.x,
367+
pos[1] + soma_x_y_z.y,
368+
pos[2] + soma_x_y_z.z,
369+
]
370+
plot_2D_point_cells(
371+
offset=pos1,
372+
plane2d=plane2d,
373+
color=color,
374+
soma_radius=radius,
375+
verbose=verbose,
376+
ax=ax,
377+
fig=fig,
378+
autoscale=False,
379+
scalebar=False,
380+
nogui=True,
381+
)
382+
elif plot_type == "schematic" or cell.id in schematic_cells:
295383
plot_2D_schematic(
296384
offset=pos,
297385
cell=cell,
298386
segment_groups=None,
299-
labels=True,
387+
labels=False,
300388
plane2d=plane2d,
301389
verbose=verbose,
302390
fig=fig,
@@ -306,7 +394,12 @@ def plot_2D(
306394
autoscale=False,
307395
square=False,
308396
)
309-
else:
397+
elif (
398+
plot_type == "detailed"
399+
or cell.id in detailed_cells
400+
or plot_type == "constant"
401+
or cell.id in constant_cells
402+
):
310403
plot_2D_cell_morphology(
311404
offset=pos,
312405
cell=cell,

pyneuroml/plot/PlotMorphologyVispy.py

Lines changed: 94 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,17 @@
1111

1212

1313
import logging
14-
import typing
15-
import numpy
14+
import random
1615
import textwrap
17-
from vispy import scene, app
16+
import typing
1817

19-
from pyneuroml.utils.plot import (
20-
DEFAULTS,
21-
get_cell_bound_box,
22-
get_next_hex_color,
23-
)
18+
import numpy
19+
from neuroml import Cell, NeuroMLDocument, Segment, SegmentGroup
20+
from neuroml.neuro_lex_ids import neuro_lex_ids
2421
from pyneuroml.pynml import read_neuroml2_file
2522
from pyneuroml.utils import extract_position_info
26-
27-
from neuroml import Cell, NeuroMLDocument, SegmentGroup, Segment
28-
from neuroml.neuro_lex_ids import neuro_lex_ids
29-
23+
from pyneuroml.utils.plot import DEFAULTS, get_cell_bound_box, get_next_hex_color
24+
from vispy import app, scene
3025

3126
logger = logging.getLogger(__name__)
3227
logger.setLevel(logging.INFO)
@@ -296,6 +291,9 @@ def plot_interactive_3D(
296291
title: typing.Optional[str] = None,
297292
theme: str = "light",
298293
nogui: bool = False,
294+
plot_spec: typing.Optional[
295+
typing.Dict[str, typing.Union[str, typing.List[int], float]]
296+
] = None,
299297
):
300298
"""Plot interactive plots in 3D using Vispy
301299
@@ -316,6 +314,7 @@ def plot_interactive_3D(
316314
- "constant": show morphology, but use constant line widths
317315
- "schematic": only plot each unbranched segment group as a straight
318316
line, not following each segment
317+
- "point": show all cells as points
319318
320319
This is only applicable for neuroml.Cell cells (ones with some
321320
morphology)
@@ -327,19 +326,33 @@ def plot_interactive_3D(
327326
:type theme: str
328327
:param nogui: toggle showing gui (for testing only)
329328
:type nogui: bool
329+
:param plot_spec: dictionary that allows passing some specifications that
330+
control how a plot is generated. This is mostly useful for large
331+
network plots where one may want to have a mix of full morphology and
332+
schematic, and point representations of cells. Possible keys are:
333+
334+
- point_fraction: what fraction of each population to plot as point cells:
335+
these cells will be randomly selected
336+
- points_cells: list of cell ids to plot as point cells
337+
- schematic_cells: list of cell ids to plot as schematics
338+
- constant_cells: list of cell ids to plot as constant widths
339+
340+
The last three lists override the point_fraction setting. If a cell id
341+
is not included in the spec here, it will follow the plot_type provided
342+
before.
330343
"""
331-
if plot_type not in ["detailed", "constant", "schematic"]:
344+
if plot_type not in ["detailed", "constant", "schematic", "point"]:
332345
raise ValueError(
333-
"plot_type must be one of 'detailed', 'constant', or 'schematic'"
346+
"plot_type must be one of 'detailed', 'constant', 'schematic', 'point'"
334347
)
335348

336349
if verbose:
337350
print(f"Plotting {nml_file}")
338351

339-
if type(nml_file) == str:
352+
if type(nml_file) is str:
340353
nml_model = read_neuroml2_file(
341354
nml_file,
342-
include_includes=True,
355+
include_includes=False,
343356
check_validity_pre_include=False,
344357
verbose=False,
345358
optimized=True,
@@ -360,7 +373,9 @@ def plot_interactive_3D(
360373
positions,
361374
pop_id_vs_color,
362375
pop_id_vs_radii,
363-
) = extract_position_info(nml_model, verbose)
376+
) = extract_position_info(
377+
nml_model, verbose, nml_file if type(nml_file) is str else ""
378+
)
364379

365380
# Collect all markers and only plot one markers object
366381
# this is more efficient than multiple markers, one for each point.
@@ -429,12 +444,45 @@ def plot_interactive_3D(
429444

430445
logger.debug(f"figure extents are: {view_min}, {view_max}")
431446

432-
for pop_id in pop_id_vs_cell:
433-
cell = pop_id_vs_cell[pop_id]
434-
pos_pop = positions[pop_id]
447+
# process plot_spec
448+
point_cells = [] # type: typing.List[int]
449+
schematic_cells = [] # type: typing.List[int]
450+
constant_cells = [] # type: typing.List[int]
451+
detailed_cells = [] # type: typing.List[int]
452+
if plot_spec is not None:
453+
try:
454+
point_cells = plot_spec["point_cells"]
455+
except KeyError:
456+
pass
457+
try:
458+
schematic_cells = plot_spec["schematic_cells"]
459+
except KeyError:
460+
pass
461+
try:
462+
constant_cells = plot_spec["constant_cells"]
463+
except KeyError:
464+
pass
465+
try:
466+
detailed_cells = plot_spec["detailed_cells"]
467+
except KeyError:
468+
pass
469+
470+
for pop_id, cell in pop_id_vs_cell.items():
471+
pos_pop = positions[pop_id] # type: typing.Dict[typing.Any, typing.List[float]]
435472

436-
for cell_index in pos_pop:
437-
pos = pos_pop[cell_index]
473+
# reinit point_cells for each loop
474+
point_cells_pop = []
475+
if len(point_cells) == 0 and plot_spec is not None:
476+
cell_indices = list(pos_pop.keys())
477+
try:
478+
point_cells_pop = random.sample(
479+
cell_indices,
480+
int(len(cell_indices) * float(plot_spec["point_fraction"])),
481+
)
482+
except KeyError:
483+
pass
484+
485+
for cell_index, pos in pos_pop.items():
438486
radius = pop_id_vs_radii[pop_id] if pop_id in pop_id_vs_radii else 10
439487
color = pop_id_vs_color[pop_id] if pop_id in pop_id_vs_color else None
440488

@@ -448,7 +496,24 @@ def plot_interactive_3D(
448496
marker_sizes.extend([radius])
449497
marker_colors.extend([color])
450498
else:
451-
if plot_type == "schematic":
499+
if (
500+
plot_type == "point"
501+
or cell_index in point_cells_pop
502+
or cell.id in point_cells
503+
):
504+
# assume that soma is 0, plot point at where soma should be
505+
soma_x_y_z = cell.get_actual_proximal(0)
506+
pos1 = [
507+
pos[0] + soma_x_y_z.x,
508+
pos[1] + soma_x_y_z.y,
509+
pos[2] + soma_x_y_z.z,
510+
]
511+
marker_points.extend([pos1])
512+
# larger than the default soma width, which would be too
513+
# small
514+
marker_sizes.extend([25])
515+
marker_colors.extend([color])
516+
elif plot_type == "schematic" or cell.id in schematic_cells:
452517
plot_3D_schematic(
453518
offset=pos,
454519
cell=cell,
@@ -459,7 +524,12 @@ def plot_interactive_3D(
459524
current_view=current_view,
460525
nogui=True,
461526
)
462-
else:
527+
elif (
528+
plot_type == "detailed"
529+
or cell.id in detailed_cells
530+
or plot_type == "constant"
531+
or cell.id in constant_cells
532+
):
463533
pts, sizes, colors = plot_3D_cell_morphology(
464534
offset=pos,
465535
cell=cell,

0 commit comments

Comments
 (0)