Skip to content

Commit

Permalink
ENH: Add streamline concatenation script
Browse files Browse the repository at this point in the history
Add streamline concatenation script and the necessary methods into the
appropriate module.
  • Loading branch information
jhlegarreta committed Sep 14, 2023
1 parent 9485556 commit 4b75d3a
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 0 deletions.
106 changes: 106 additions & 0 deletions bin/wm_concatenate_streamlines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#!/usr/bin/env python

"""Concatenate streamlines contained in different tractograms. Point and cell
scalar and tensor data are not preserved.
"""

import os
import argparse

import vtk

import whitematteranalysis as wma


def pipeline(fnames, out_fname, verbose=False):

out_polydata = vtk.vtkPolyData()

# Get subject identifier from unique input filename
for index, fname in enumerate(fnames):
sub_id = os.path.splitext(os.path.basename(fname))[0]
id_msg = f"{os.path.basename(__file__)} {index + 1} / {len(fnames)}"
msg = f"**Starting subject: {sub_id}"
print(id_msg + msg)

# Read tractogram
msg = f"**Reading input: {sub_id}"
print(id_msg + msg)

polydata = wma.io.read_polydata(fname)
print(f"Number of streamlines: {polydata.GetNumberOfLines()}")

# Concatenate
out_polydata = wma.filter.concatenate_streamlines(
[out_polydata, polydata],
_verbose=verbose
)

print(f"Number of streamlines concatenated: {out_polydata.GetNumberOfLines()}")

# Output
try:
print(f"Writing output polydata {out_fname}...")
wma.io.write_polydata(out_polydata, out_fname)
print(f"Wrote output {out_fname}.")
except:
print("Unknown exception in IO")
raise


def _build_arg_parser():

parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
'inputDirectory',
help='Directory containing tractography files (.*vtk|.*vtp).')
parser.add_argument(
'outputFilename',
help='Output filename (.*vtk|.*vtp).')
parser.add_argument(
'-verbose', action='store_true', dest="flag_verbose",
help='Verbose. Run with -verbose to print operation information.')

return parser


def _parse_args(parser):

args = parser.parse_args()
return args


def main():
parser = _build_arg_parser()
args = _parse_args(parser)

if not os.path.isdir(args.inputDirectory):
print(f"Error: Input directory {args.inputDirectory} does not exist.")
exit()

out_fname = args.outputFilename
if os.path.exists(out_fname):
msg = f"Output file {out_fname} exists. Remove or rename the output file."
parser.error(msg)

print(f"{os.path.basename(__file__)}. Starting streamline concatenation.")
print("")
print("=====input directory======\n", args.inputDirectory)
print("=====output filename=====\n", args.outputFilename)
print("==========================")

verbose = args.flag_verbose

fnames = wma.io.list_vtk_files(args.inputDirectory)

print(f"<{os.path.basename(__file__)}> Input number of files: ", len(fnames))

pipeline(fnames, out_fname, verbose=verbose)

exit()


if __name__ == "__main__":
main()
58 changes: 58 additions & 0 deletions whitematteranalysis/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"""

import os

import vtk
import numpy

Expand Down Expand Up @@ -250,6 +252,62 @@ def preprocess(inpd, min_length_mm,
else:
return outpd


def concatenate(polydatas, _verbose=False):
"""Concatenate a list of polydatas.
Parameters
----------
polydatas : list
vtkPolyData objects.
_verbose : bool, optional
True if processing needs to be printed.
Returns
-------
out_polydata : vtkPolyData
Concatenated polydata.
"""

out_polydata = vtk.vtkPolyData()
out_lines = vtk.vtkCellArray()
out_points = vtk.vtkPoints()

for polydata in polydatas:

# Loop over lines
polydata.GetLines().InitTraversal()
out_lines.InitTraversal()

in_points = polydata.GetPoints()

for line_id in range(polydata.GetNumberOfLines()):

point_ids = vtk.vtkIdList()
polydata.GetLines().GetNextCell(point_ids)

if _verbose:
if line_id % 100 == 0:
print(f"<{os.path.basename(__file__)}> Line: {line_id} / {polydata.GetNumberOfLines()}")

# Get points for each point id and add to output polydata
cell_point_ids = vtk.vtkIdList()

for point_id in range(point_ids.GetNumberOfIds()):

point = in_points.GetPoint(point_ids.GetId(point_id))
idx = out_points.InsertNextPoint(point)
cell_point_ids.InsertNextId(idx)

out_lines.InsertNextCell(cell_point_ids)

# Put data into output polydata
out_polydata.SetLines(out_lines)
out_polydata.SetPoints(out_points)

return out_polydata


def downsample(inpd, output_number_of_lines, return_indices=False, preserve_point_data=False, preserve_cell_data=True, initial_indices=None, verbose=True, random_seed=None):
""" Random (down)sampling of fibers without replacement. """

Expand Down

0 comments on commit 4b75d3a

Please sign in to comment.