Skip to content

Commit

Permalink
Merge pull request #9 from skoudoro/OPT_trackers
Browse files Browse the repository at this point in the history
Rebase for openmp on mac + function pointer nogil
  • Loading branch information
gabknight authored Feb 9, 2024
2 parents fb4005f + 4fd1b42 commit 4619c10
Show file tree
Hide file tree
Showing 16 changed files with 595 additions and 51 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,13 @@ jobs:
- name: Build the wheel
run: python -m cibuildwheel --output-dir dist
env:
CIBW_BEFORE_ALL_MACOS: "brew install llvm libomp"
CIBW_BUILD: ${{ matrix.cibw_python }}
CIBW_ARCHS_MACOS: ${{ matrix.cibw_arch }}
CIBW_TEST_SKIP: "*_aarch64 *-macosx_arm64"
CIBW_MANYLINUX_X86_64_IMAGE: ${{ matrix.cibw_manylinux }}
CIBW_MANYLINUX_I686_IMAGE: ${{ matrix.cibw_manylinux }}
CC: clang
- name: Rename Python version
run: echo "PY_VERSION=$(echo ${{ matrix.cibw_python }} | cut -d- -f1)" >> $GITHUB_ENV
- uses: actions/upload-artifact@v4
Expand Down
17 changes: 7 additions & 10 deletions .github/workflows/test_template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,12 @@ jobs:
else
tools/ci/install_dependencies.sh
fi
# No need to update mingw-w64, we use msvc
# mingw-w64 does not manage well openmp threads so we avoid it for now
# Note that compilation works with mingw-w64 but 2-3 tests fail due to openmp
# - name: Install rtools (mingw-w64)
# if runner.os == 'Windows'
# run: |
# choco install rtools -y --no-progress --force --version=4.0.0.20220206
# echo "/c/rtools40/ucrt64/bin;" >> $GITHUB_PATH
# echo "PKG_CONFIG_PATH=/c/opt/64/lib/pkgconfig;" >> $GITHUB_ENV
- name: Install OpenMP on macOS
if: runner.os == 'macOS'
run: |
brew install llvm
brew install libomp
echo "CC=clang" >> $GITHUB_ENV
- name: Install DIPY
run: |
if [ "${{ inputs.use-pre }}" == "true" ]; then
Expand All @@ -126,7 +123,7 @@ jobs:
tools/ci/run_tests.sh
fi
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@v4
if: ${{ fromJSON(env.COVERAGE) }}
with:
token: ${{ secrets.CODECOV_TOKEN }}
Expand Down
1 change: 0 additions & 1 deletion dipy/direction/pmf.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ cdef class PmfGen:
double[:] pmf
double[:, :, :, :] data
double[:, :] vertices
object sphere

cpdef double[:] get_pmf(self, double[::1] point)
cdef double* get_pmf_c(self, double* point) noexcept nogil
Expand Down
1 change: 0 additions & 1 deletion dipy/direction/pmf.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ cdef class PmfGen:
double[:, :, :, :] data,
object sphere):
self.data = np.asarray(data, dtype=float, order='C')
self.sphere = sphere
self.vertices = np.asarray(sphere.vertices, dtype=float)

cpdef double[:] get_pmf(self, double[::1] point):
Expand Down
16 changes: 16 additions & 0 deletions dipy/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,22 @@ np_dep = declare_dependency(include_directories: inc_np)
# Define Optimisation for cython extensions
# ------------------------------------------------------------------------
omp = dependency('openmp', required: false)
if not omp.found() and meson.get_compiler('c').get_id() == 'clang'
# Check for libomp (OpenMP) using Homebrew
brew = find_program('brew', required : false)
if brew.found()
output = run_command(brew, 'list', 'libomp', check: true)
output = output.stdout().strip()
if output.contains('/libomp/')
omp_prefix = fs.parent(output.split('\n')[0])
message('OpenMP Found: YES (Manual search) - ', omp_prefix)
omp = declare_dependency(compile_args : ['-Xpreprocessor', '-fopenmp'],
link_args : ['-L' + omp_prefix + '/lib', '-lomp'],
include_directories : include_directories(omp_prefix / 'include')
)
endif
endif
endif

# SSE intrinsics
sse2_cflags = []
Expand Down
109 changes: 109 additions & 0 deletions dipy/tracking/fast_tracking.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
cimport numpy as cnp

from dipy.tracking.stopping_criterion cimport StoppingCriterion
from dipy.direction.pmf cimport PmfGen


cdef class TrackingParameters():
cdef:
int max_len
double step_size
double[:] voxel_size
double[3] inv_voxel_size


ctypedef int (*func_ptr)(double* point,
double* direction,
ProbabilisticTrackingParameters,
PmfGen) noexcept nogil

cpdef list generate_tractogram(double[:,::1] seed_positions,
double[:,::1] seed_directions,
StoppingCriterion sc,
ProbabilisticTrackingParameters params,
PmfGen pmf_gen)


cdef int generate_tractogram_c(double[:,::1] seed_positions,
double[:,::1] seed_directions,
int nbr_seeds,
StoppingCriterion sc,
ProbabilisticTrackingParameters params,
PmfGen pmf_gen,
func_ptr traker,

Check failure on line 33 in dipy/tracking/fast_tracking.pxd

View workflow job for this annotation

GitHub Actions / Check for spelling errors

traker ==> tracker
double[:,:,:] streamlines,
double[:] status)


cdef int generate_local_streamline(double* seed,
double* position,
double* stream,
func_ptr tracker,
# sc_ptr stopping_criterion,
# pmf_ptr pmf_gen,
StoppingCriterion sc,
ProbabilisticTrackingParameters params,
PmfGen pmf_gen) noexcept nogil

cdef int trilinear_interpolate4d_c(double[:, :, :, :] data,
double* point,
double* result) noexcept nogil

cdef int get_pmf(double* pmf,
double* point,
PmfGen pmf_gen,
double pmf_threshold,
int pmf_len) noexcept nogil


cdef class ProbabilisticTrackingParameters(TrackingParameters):
cdef:
double cos_similarity
double pmf_threshold
#PmfGen pmf_gen
int pmf_len
double[:, :] vertices


cdef int probabilistic_tracker(double* point,
double* direction,
ProbabilisticTrackingParameters params,
PmfGen pmf_gen) noexcept nogil

cdef class DeterministicTrackingParameters(ProbabilisticTrackingParameters):
pass


cdef int deterministic_maximum_tracker(double* point,
double* direction,
DeterministicTrackingParameters params,
PmfGen pmf_gen,)

cdef class ParallelTransportTrackingParameters(ProbabilisticTrackingParameters):
cdef:
double angular_separation
double data_support_exponent
double[3][3] frame
double k1
double k2
double k_small
double last_val
double last_val_cand
double max_angle
double max_curvature
double[3] position
int probe_count
double probe_length
double probe_normalizer
int probe_quality
double probe_radius
double probe_step_size
double[9] propagator
int rejection_sampling_max_try
int rejection_sampling_nbr_sample


cdef int parallel_transport_tracker(double* point,
double* direction,
ParallelTransportTrackingParameters params,
PmfGen pmf_gen)
Loading

0 comments on commit 4619c10

Please sign in to comment.