diff --git a/bin/wm_concatenate_streamlines.py b/bin/wm_concatenate_streamlines.py new file mode 100644 index 00000000..fd763732 --- /dev/null +++ b/bin/wm_concatenate_streamlines.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""Concatenate streamlines contained in different tractograms. Point and cell +scalar and tensor data are not preserved. +""" + +import os +import argparse + +import vtk + +import whitematteranalysis as wma + + +def pipeline(fnames, out_fname, verbose=False): + + out_polydata = vtk.vtkPolyData() + + # Get subject identifier from unique input filename + for index, fname in enumerate(fnames): + sub_id = os.path.splitext(os.path.basename(fname))[0] + id_msg = f"{os.path.basename(__file__)} {index + 1} / {len(fnames)}" + msg = f"**Starting subject: {sub_id}" + print(id_msg + msg) + + # Read tractogram + msg = f"**Reading input: {sub_id}" + print(id_msg + msg) + + polydata = wma.io.read_polydata(fname) + print(f"Number of streamlines: {polydata.GetNumberOfLines()}") + + # Concatenate + out_polydata = wma.filter.concatenate_streamlines( + [out_polydata, polydata], + _verbose=verbose + ) + + print(f"Number of streamlines concatenated: {out_polydata.GetNumberOfLines()}") + + # Output + try: + print(f"Writing output polydata {out_fname}...") + wma.io.write_polydata(out_polydata, out_fname) + print(f"Wrote output {out_fname}.") + except: + print("Unknown exception in IO") + raise + + +def _build_arg_parser(): + + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument( + 'inputDirectory', + help='Directory containing tractography files (.*vtk|.*vtp).') + parser.add_argument( + 'outputFilename', + help='Output filename (.*vtk|.*vtp).') + parser.add_argument( + '-verbose', action='store_true', dest="flag_verbose", + help='Verbose. Run with -verbose to print operation information.') + + return parser + + +def _parse_args(parser): + + args = parser.parse_args() + return args + + +def main(): + parser = _build_arg_parser() + args = _parse_args(parser) + + if not os.path.isdir(args.inputDirectory): + print(f"Error: Input directory {args.inputDirectory} does not exist.") + exit() + + out_fname = args.outputFilename + if os.path.exists(out_fname): + msg = f"Output file {out_fname} exists. Remove or rename the output file." + parser.error(msg) + + print(f"{os.path.basename(__file__)}. Starting streamline concatenation.") + print("") + print("=====input directory======\n", args.inputDirectory) + print("=====output filename=====\n", args.outputFilename) + print("==========================") + + verbose = args.flag_verbose + + fnames = wma.io.list_vtk_files(args.inputDirectory) + + print(f"<{os.path.basename(__file__)}> Input number of files: ", len(fnames)) + + pipeline(fnames, out_fname, verbose=verbose) + + exit() + + +if __name__ == "__main__": + main() diff --git a/whitematteranalysis/filter.py b/whitematteranalysis/filter.py index 9e542cc0..f581b3ce 100644 --- a/whitematteranalysis/filter.py +++ b/whitematteranalysis/filter.py @@ -17,6 +17,8 @@ """ +import os + import vtk import numpy @@ -250,6 +252,62 @@ def preprocess(inpd, min_length_mm, else: return outpd + +def concatenate(polydatas, _verbose=False): + """Concatenate a list of polydatas. + + Parameters + ---------- + polydatas : list + vtkPolyData objects. + _verbose : bool, optional + True if processing needs to be printed. + + Returns + ------- + out_polydata : vtkPolyData + Concatenated polydata. + """ + + out_polydata = vtk.vtkPolyData() + out_lines = vtk.vtkCellArray() + out_points = vtk.vtkPoints() + + for polydata in polydatas: + + # Loop over lines + polydata.GetLines().InitTraversal() + out_lines.InitTraversal() + + in_points = polydata.GetPoints() + + for line_id in range(polydata.GetNumberOfLines()): + + point_ids = vtk.vtkIdList() + polydata.GetLines().GetNextCell(point_ids) + + if _verbose: + if line_id % 100 == 0: + print(f"<{os.path.basename(__file__)}> Line: {line_id} / {polydata.GetNumberOfLines()}") + + # Get points for each point id and add to output polydata + cell_point_ids = vtk.vtkIdList() + + for point_id in range(point_ids.GetNumberOfIds()): + + point = in_points.GetPoint(point_ids.GetId(point_id)) + idx = out_points.InsertNextPoint(point) + cell_point_ids.InsertNextId(idx) + + out_lines.InsertNextCell(cell_point_ids) + + # Put data into output polydata + out_polydata.SetLines(out_lines) + out_polydata.SetPoints(out_points) + + return out_polydata + + def downsample(inpd, output_number_of_lines, return_indices=False, preserve_point_data=False, preserve_cell_data=True, initial_indices=None, verbose=True, random_seed=None): """ Random (down)sampling of fibers without replacement. """