Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add streamline concatenation script #150

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions bin/wm_concatenate_streamlines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Concatenate streamlines contained in different tractograms. Point and cell
scalar and tensor data are not preserved.
"""

import os
import argparse

import vtk

import whitematteranalysis as wma


def pipeline(fnames, out_fname, verbose=False):
    """Concatenate the streamlines of all input tractograms and write the result.

    Parameters
    ----------
    fnames : list
        Paths of the input tractography files (.vtk|.vtp).
    out_fname : str
        Path of the output tractography file.
    verbose : bool, optional
        True if processing needs to be printed.
    """

    out_polydata = vtk.vtkPolyData()

    # Get subject identifier from unique input filename
    for index, fname in enumerate(fnames):
        sub_id = os.path.splitext(os.path.basename(fname))[0]
        id_msg = f"{os.path.basename(__file__)} {index + 1} / {len(fnames)}"
        msg = f"**Starting subject: {sub_id}"
        print(id_msg + msg)

        # Read tractogram
        msg = f"**Reading input: {sub_id}"
        print(id_msg + msg)

        polydata = wma.io.read_polydata(fname)
        print(f"Number of streamlines: {polydata.GetNumberOfLines()}")

        # Concatenate. NOTE: the helper added to wma.filter in this change is
        # named `concatenate` (see filter.py); `concatenate_streamlines` does
        # not exist, so the original call would raise AttributeError.
        out_polydata = wma.filter.concatenate(
            [out_polydata, polydata],
            _verbose=verbose
        )

    print(f"Number of streamlines concatenated: {out_polydata.GetNumberOfLines()}")

    # Output
    try:
        print(f"Writing output polydata {out_fname}...")
        wma.io.write_polydata(out_polydata, out_fname)
        print(f"Wrote output {out_fname}.")
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # not intercepted; the exception is still re-raised after logging.
        print("Unknown exception in IO")
        raise


def _build_arg_parser():
    """Build and return the command-line interface for this script."""

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # Positional arguments; declaration order defines the CLI order.
    positionals = (
        ('inputDirectory',
         'Directory containing tractography files (.*vtk|.*vtp).'),
        ('outputFilename',
         'Output filename (.*vtk|.*vtp).'),
    )
    for arg_name, arg_help in positionals:
        parser.add_argument(arg_name, help=arg_help)

    # Optional verbosity flag (single-dash spelling kept for compatibility).
    parser.add_argument(
        '-verbose',
        action='store_true',
        dest='flag_verbose',
        help='Verbose. Run with -verbose to print operation information.',
    )

    return parser


def _parse_args(parser):
    """Parse ``sys.argv`` with the given parser and return the namespace."""
    return parser.parse_args()


def main():
    """Entry point: validate CLI arguments and run the concatenation pipeline."""
    parser = _build_arg_parser()
    args = _parse_args(parser)

    # Use the parser's error channel for both validation failures so the two
    # cases behave consistently (usage + message on stderr, exit status 2);
    # the original mixed print()+exit() with parser.error().
    if not os.path.isdir(args.inputDirectory):
        parser.error(f"Input directory {args.inputDirectory} does not exist.")

    out_fname = args.outputFilename
    if os.path.exists(out_fname):
        msg = f"Output file {out_fname} exists. Remove or rename the output file."
        parser.error(msg)

    print(f"{os.path.basename(__file__)}. Starting streamline concatenation.")
    print("")
    print("=====input directory======\n", args.inputDirectory)
    print("=====output filename=====\n", args.outputFilename)
    print("==========================")

    verbose = args.flag_verbose

    fnames = wma.io.list_vtk_files(args.inputDirectory)

    print(f"<{os.path.basename(__file__)}> Input number of files: ", len(fnames))

    pipeline(fnames, out_fname, verbose=verbose)
    # No trailing exit(): returning normally already yields a zero exit status,
    # and exit() is the interactive-shell helper rather than sys.exit().


# Run the CLI entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
58 changes: 58 additions & 0 deletions whitematteranalysis/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

"""

import os

import vtk
import numpy

Expand Down Expand Up @@ -250,6 +252,62 @@ def preprocess(inpd, min_length_mm,
else:
return outpd


def concatenate(polydatas, _verbose=False):
    """Concatenate a list of polydatas.

    Only line cells and their points are copied; point and cell scalar and
    tensor data are not preserved.

    Parameters
    ----------
    polydatas : list
        vtkPolyData objects.
    _verbose : bool, optional
        True if processing needs to be printed.

    Returns
    -------
    out_polydata : vtkPolyData
        Concatenated polydata.
    """

    out_polydata = vtk.vtkPolyData()
    out_lines = vtk.vtkCellArray()
    out_points = vtk.vtkPoints()

    for polydata in polydatas:

        in_points = polydata.GetPoints()
        # Hoisted loop invariant: the line count does not change while reading.
        n_lines = polydata.GetNumberOfLines()

        # Reset the read cursor on the input cell array before GetNextCell().
        # (The original also called InitTraversal() on `out_lines`, which is
        # meaningless for a write-only output array and has been removed.)
        polydata.GetLines().InitTraversal()

        for line_id in range(n_lines):

            point_ids = vtk.vtkIdList()
            polydata.GetLines().GetNextCell(point_ids)

            if _verbose and line_id % 100 == 0:
                print(f"<{os.path.basename(__file__)}> Line: {line_id} / {n_lines}")

            # Copy each point into the output and record its new index so the
            # output cell references the output point array.
            cell_point_ids = vtk.vtkIdList()

            for point_id in range(point_ids.GetNumberOfIds()):

                point = in_points.GetPoint(point_ids.GetId(point_id))
                idx = out_points.InsertNextPoint(point)
                cell_point_ids.InsertNextId(idx)

            out_lines.InsertNextCell(cell_point_ids)

    # Put data into output polydata
    out_polydata.SetLines(out_lines)
    out_polydata.SetPoints(out_points)

    return out_polydata


def downsample(inpd, output_number_of_lines, return_indices=False, preserve_point_data=False, preserve_cell_data=True, initial_indices=None, verbose=True, random_seed=None):
""" Random (down)sampling of fibers without replacement. """

Expand Down