Skip to content

Commit

Permalink
try5
Browse files Browse the repository at this point in the history
  • Loading branch information
dkazanc committed Apr 30, 2024
1 parent 0e4908d commit fb8ce77
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions httomolibgpu/misc/corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,32 @@
# ---------------------------------------------------------------------------
""" Module for data correction """


import numpy as xp

cupy_run = False

try:
import cupy as cp

# import nvtx
cupy_run = True
try:
cp.cuda.Device(0).compute_capability
cupy_run = True

except cp.cuda.runtime.CUDARuntimeError:
print("Cupy library is a required dependency for HTTomolibgpu, please install")
import numpy as np
except ImportError:
print("Cupy library is a required dependency for HTTomolibgpu, please install")
import numpy as np

# cupy_run = False
# try:
# import cupy as cp

# # import nvtx
# cupy_run = True
# except ImportError:
# print("Cupy library is a required dependency for HTTomolibgpu, please install")

try:
from cucim.skimage.filters import median
Expand All @@ -38,7 +56,8 @@
)

from typing import Tuple
import numpy as np

# import numpy as np
from numpy import float32

if cupy_run:
Expand All @@ -52,11 +71,11 @@

# @nvtx.annotate()
def median_filter(
data: cp.ndarray,
data: xp.ndarray,
kernel_size: int = 3,
axis: int = 0,
dif: float = 0.0,
) -> cp.ndarray:
) -> xp.ndarray:
"""
Apply 2D or 3D median or dezinger (when dif>0) filter to a 3D array.
Expand Down Expand Up @@ -101,7 +120,7 @@ def median_filter(
raise ValueError("The axis should be 0,1,2 or None for full 3d processing")

dz, dy, dx = data.shape
output = cp.empty(data.shape, dtype=input_type, order="C")
output = xp.empty(data.shape, dtype=input_type, order="C")

if axis == 0:
for j in range(dz):
Expand Down Expand Up @@ -142,7 +161,7 @@ def median_filter(
output = data;
}
"""
thresholding_kernel = cp.ElementwiseKernel(
thresholding_kernel = xp.ElementwiseKernel(
"T data, raw float32 dif",
"T output",
kernel,
Expand All @@ -155,8 +174,8 @@ def median_filter(


def remove_outlier(
data: cp.ndarray, kernel_size: int = 3, axis: int = 0, dif: float = 0.1
) -> cp.ndarray:
data: xp.ndarray, kernel_size: int = 3, axis: int = 0, dif: float = 0.1
) -> xp.ndarray:
"""
Selectively applies 3D median filter to a 3D array to remove outliers. Also called a dezinger.
Expand Down

0 comments on commit fb8ce77

Please sign in to comment.