Commit
Move async fetch methods out of Trolley and Downloader, into individual Downloader implementations (rad69 and wado_uri). Fixes #43
sjoerdk committed Apr 19, 2024
1 parent 3171f50 commit 4424b63
Showing 6 changed files with 131 additions and 129 deletions.
24 changes: 0 additions & 24 deletions dicomtrolley/core.py
@@ -523,30 +523,6 @@ def datasets(self, objects: Sequence[DICOMDownloadable]):
"""
raise NotImplementedError()

def datasets_async(
self, instances: Sequence[InstanceReference], max_workers=None
):
"""Retrieve each instance in multiple threads
Parameters
----------
instances: Sequence[InstanceReference]
Retrieve dataset for each of these instances
max_workers: int, optional
Use this number of workers in ThreadPoolExecutor. Defaults to
default for ThreadPoolExecutor
Raises
------
DICOMTrolleyError
When a server response cannot be parsed as DICOM
Returns
-------
Iterator[Dataset, None, None]
"""
raise NotImplementedError()


class QueryLevels(str, Enum):
"""Used in dicom queries to indicate how rich the search should be"""
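With datasets_async removed, the Downloader base class keeps only the synchronous entry point; each concrete downloader now decides internally whether to use threads. A minimal sketch of the remaining contract (a condensed, hypothetical rendering; see core.py for the real class):

from typing import Sequence


class Downloader:
    """Sketch of the slimmed-down base class; subclasses implement datasets()."""

    def datasets(self, objects: Sequence["DICOMDownloadable"]):
        """Retrieve a Dataset for each instance contained in objects"""
        raise NotImplementedError()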
152 changes: 93 additions & 59 deletions dicomtrolley/rad69.py
@@ -45,9 +45,10 @@ def __init__(
self,
session,
url,
http_chunk_size=5242880,
request_per_series=True,
errors_to_ignore=None,
use_async=False,
max_workers=4,
):
"""
Parameters
@@ -56,34 +57,65 @@ def __init__(
A logged in session over which rad69 calls can be made
url: str
rad69 endpoint, including protocol and port. Like https://server:2525/rids
http_chunk_size: int, optional
Number of bytes to read each time when streaming chunked rad69 responses.
Defaults to 5MB (5242880 bytes)
request_per_series: bool, optional
If True, split rad69 requests per series when downloading. If False,
request all instances at once. Splitting reduces load on the server.
Defaults to True.
errors_to_ignore: List[Type], optional
Errors of this type encountered during download are caught and skipped.
Defaults to empty list, meaning any error is propagated
use_async: bool, optional
If True, download will split instances into chunks and download each
chunk in a separate thread. If False, use a single thread. Defaults to False.
max_workers: int, optional
Only used if use_async=True. Number of workers to use for multi-threading.
"""

self.session = session
self.url = url
self.http_chunk_size = http_chunk_size

# Number of bytes to read each time when streaming chunked rad69 responses.
# Defaults to 5MB (5242880 bytes)
self.http_chunk_size = 5242880

if errors_to_ignore is None:
errors_to_ignore = []
self.errors_to_ignore = errors_to_ignore
self.template = RAD69_SOAP_REQUEST_TEMPLATE
self.post_headers = {"Content-Type": "application/soap+xml"}
self.request_per_series = request_per_series
self.use_async = use_async
self.max_workers = max_workers

def datasets(self, objects: Sequence[DICOMDownloadable]):
"""Retrieve all instances via rad69
A Rad69 request typically contains multiple instances. The data for all
instances is then streamed back as one multipart HTTP response.
Raises
------
NonInstanceParameterError
If objects contain non-instance targets like a StudyInstanceUID.
Rad69 can only download instances
Returns
-------
Iterator[Dataset, None, None]
"""
if self.use_async:
yield from self.datasets_async(
objects, max_workers=self.max_workers
)
else:
yield from self.datasets_single_thread(objects)
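Together the constructor flags and this dispatch replace the datasets_async entry point that was removed from the base class. A usage sketch, assuming a logged-in requests.Session named session and a list of InstanceReference objects named instances (the endpoint URL is illustrative):

from dicomtrolley.rad69 import Rad69

rad69 = Rad69(
    session=session,  # assumed: a logged-in requests.Session
    url="https://server:2525/rids",  # illustrative endpoint
    use_async=True,
    max_workers=4,
)
for dataset in rad69.datasets(instances):  # internally routes to datasets_async
    dataset.save_as(f"/tmp/{dataset.SOPInstanceUID}.dcm")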

def datasets_single_thread(self, objects: Sequence[DICOMDownloadable]):
"""Retrieve all instances via rad69, without async
A Rad69 request typically contains multiple instances. The data for all
instances is then streamed back as one multipart HTTP response.
Raises
------
NonInstanceParameterError
@@ -111,6 +143,62 @@ def datasets(self, objects: Sequence[DICOMDownloadable]):
else:
return self.download_iterator(instances)

def datasets_async(
self, objects: Sequence[DICOMDownloadable], max_workers
):
"""Split instances into chunks and retrieve each chunk in separate thread
Parameters
----------
objects: Sequence[DICOMDownloadable]
Retrieve dataset for each instance in these objects
max_workers: int
Use this number of workers in the ThreadPoolExecutor. If None is passed,
a single thread is used.
Notes
-----
rad69 allows any number of slices to be combined in one request. The response
is a chunked multi-part http response with all image data. Requesting each
slice individually is inefficient. Requesting all slices in one thread might
limit speed. Somewhere in the middle seems the best bet for optimal speed.
This function splits all instances between the available workers and lets
workers process the response streams.
Raises
------
DICOMTrolleyError
When a server response cannot be parsed as DICOM
Returns
-------
Iterator[Dataset, None, None]
"""
instances = to_instance_refs(objects) # raise exception if needed

# max_workers=None means let the executor figure it out. But for rad69 we
# still need to determine how many instances to retrieve at once with each
# worker. Unlimited workers make no sense here. Just use a single thread.
if max_workers is None:
max_workers = 1

with FuturesSession(
session=self.session,
executor=ThreadPoolExecutor(max_workers=max_workers),
) as futures_session:
futures = []
for instance_bin in self.split_instances(instances, max_workers):
futures.append(
futures_session.post(
url=self.url,
headers=self.post_headers,
data=self.create_instances_request(instance_bin),
)
)

for future in as_completed(futures):
yield from self.parse_rad69_response(future.result())
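This is the requests-futures pattern in miniature: one POST per instance bin, responses parsed as they complete. A condensed standalone sketch of the same idea (URL and payloads are illustrative):

from concurrent.futures import ThreadPoolExecutor, as_completed

import requests
from requests_futures.sessions import FuturesSession

with FuturesSession(
    session=requests.Session(),  # stand-in for a logged-in session
    executor=ThreadPoolExecutor(max_workers=2),
) as futures_session:
    futures = [
        futures_session.post("https://server:2525/rids", data=payload)
        for payload in (b"request for bin 1", b"request for bin 2")  # illustrative
    ]
    for future in as_completed(futures):
        response = future.result()  # re-raises any exception from the worker
        # rad69.py feeds each response to parse_rad69_response at this point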

def series_download_iterator(
self, instances: Sequence[InstanceReference], index=0
):
@@ -320,60 +408,6 @@ def split_instances(instances: Sequence[InstanceReference], num_bins):
for i in range(0, len(instances), bin_size):
yield instances[i : i + bin_size]
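As a worked example of this binning, assuming the collapsed part of split_instances computes bin_size as ceil(len(instances) / num_bins):

from math import ceil

instances = list(range(10))  # stand-ins for InstanceReference objects
num_bins = 4
bin_size = ceil(len(instances) / num_bins)  # 3

bins = [instances[i : i + bin_size] for i in range(0, len(instances), bin_size)]
print(bins)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] -> near-even load per worker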

def datasets_async(
self, instances: Sequence[InstanceReference], max_workers=4
):
"""Split instances into chunks and retrieve each chunk in separate thread
Parameters
----------
instances: Sequence[InstanceReference]
Retrieve dataset for each of these instances
max_workers: int, optional
Use this number of workers in ThreadPoolExecutor. Defaults to
default for ThreadPoolExecutor
Notes
-----
rad69 allows any number of slices to be combined in one request. The response
is a chunked multi-part http response with all image data. Requesting each
slice individually is inefficient. Requesting all slices in one thread might
limit speed. Somewhere in the middle seems the best bet for optimal speed.
This function splits all instances between the available workers and lets
workers process the response streams.
Raises
------
DICOMTrolleyError
When a server response cannot be parsed as DICOM
Returns
-------
Iterator[Dataset, None, None]
"""
# max_workers=None means let the executor figure it out. But for rad69 we
# still need to determine how many instances to retrieve at once with each
# worker. Unlimited workers make no sense here. Just use a single thread.
if max_workers is None:
max_workers = 1

with FuturesSession(
session=self.session,
executor=ThreadPoolExecutor(max_workers=max_workers),
) as futures_session:
futures = []
for instance_bin in self.split_instances(instances, max_workers):
futures.append(
futures_session.post(
url=self.url,
headers=self.post_headers,
data=self.create_instances_request(instance_bin),
)
)

for future in as_completed(futures):
yield from self.parse_rad69_response(future.result())


class Rad69ServerError(DICOMTrolleyError):
"""Represents a valid error response from a rad69 server"""
42 changes: 1 addition & 41 deletions dicomtrolley/trolley.py
@@ -110,14 +110,8 @@ def download(
if not isinstance(objects, Sequence):
objects = [objects] # if just a single item to download is passed
logger.info(f"Downloading {len(objects)} object(s) to '{output_dir}'")
if use_async:
datasets = self.fetch_all_datasets_async(
objects=objects, max_workers=max_workers
)
else:
datasets = self.fetch_all_datasets(objects=objects)

for dataset in datasets:
for dataset in self.fetch_all_datasets(objects=objects):
self.storage.save(dataset=dataset, path=output_dir)

def fetch_all_datasets(self, objects: Sequence[DICOMDownloadable]):
@@ -181,37 +175,3 @@ def obtain_references(
)
references += study.contained_references(max_level=max_level)
return references

def fetch_all_datasets_async(self, objects, max_workers=None):
"""Get DICOM dataset for each instance given objects using multiple threads.
Parameters
----------
objects: Sequence[DICOMDownloadable]
get dataset for each instance contained in these objects
max_workers: int, optional
Max number of ThreadPoolExecutor workers to use. Defaults to
ThreadPoolExecutor default
Raises
------
DICOMTrolleyError
If getting or parsing of any instance fails
Returns
-------
Iterator[Dataset, None, None]
The downloaded dataset and the context that was used to download it
"""
try:
yield from self.downloader.datasets_async(
instances=objects,
max_workers=max_workers,
)
except NonInstanceParameterError:
yield from self.downloader.datasets_async(
instances=self.obtain_references(
objects=objects, max_level=DICOMObjectLevels.INSTANCE
),
max_workers=max_workers,
)
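Trolley.download is now a single code path: whether fetching is threaded is fixed when the downloader is constructed, not per call. A hedged sketch (searcher, session and studies are assumed to exist already):

from dicomtrolley.trolley import Trolley
from dicomtrolley.wado_uri import WadoURI

trolley = Trolley(
    searcher=searcher,  # any Searcher implementation, e.g. a Mint instance
    downloader=WadoURI(session, "https://server:8080/wado", use_async=True),
)
trolley.download(objects=studies, output_dir="/tmp/dicom")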
31 changes: 30 additions & 1 deletion dicomtrolley/wado_uri.py
@@ -26,7 +26,7 @@
class WadoURI(Downloader):
"""A connection to a WADO-URI server"""

def __init__(self, session, url):
def __init__(self, session, url, use_async=False, max_workers=None):
"""
Parameters
----------
@@ -35,10 +35,18 @@ def __init__(self, session, url):
url: str
WADO-URI endpoint, including protocol and port. Like
https://server:8080/wado
use_async: bool, optional
If True, download will split instances into chunks and download each
chunk in a separate thread. If False, use a single thread. Defaults to False.
max_workers: Optional[int]
Only used if use_async=True. Number of workers to use for multi-threading.
Defaults to None, meaning the ThreadPoolExecutor default is used.
"""

self.session = session
self.url = url
self.use_async = use_async
self.max_workers = max_workers

@staticmethod
def to_wado_parameters(instance):
@@ -113,6 +121,27 @@ def datasets(self, objects: Sequence[DICOMDownloadable]):
-------
Iterator[Dataset, None, None]
Raises
------
NonInstanceParameterError
If objects contain non-instance targets like a StudyInstanceUID.
wado_uri can only download instances
"""
if self.use_async:
yield from self.datasets_async(
objects, max_workers=self.max_workers
)
else:
yield from self.datasets_single_thread(objects)

def datasets_single_thread(self, objects: Sequence[DICOMDownloadable]):
"""Retrieve each instance in objects
Returns
-------
Iterator[Dataset, None, None]
Raises
------
NonInstanceParameterError
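WadoURI gets the same two constructor flags and the same datasets() dispatch as rad69, so enabling threaded downloads is a constructor-only change. A minimal sketch, assuming a logged-in session and a list of instances:

from dicomtrolley.wado_uri import WadoURI

wado = WadoURI(
    session,
    "https://server:8080/wado",
    use_async=True,
    max_workers=None,  # None falls back to the ThreadPoolExecutor default
)
datasets = list(wado.datasets(instances))  # routes to datasets_async internally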
4 changes: 2 additions & 2 deletions tests/test_rad69.py
@@ -247,8 +247,8 @@ def test_wado_datasets_async(a_rad69, requests_mock):
InstanceReference(study_uid=1, series_uid=2, instance_uid=3),
InstanceReference(study_uid=4, series_uid=5, instance_uid=6),
]

datasets = [x for x in a_rad69.datasets_async(instances)]
a_rad69.use_async = True
datasets = [x for x in a_rad69.datasets(instances)]
assert len(datasets) == 2
assert datasets[0].PatientName == "patient1"
assert (
7 changes: 5 additions & 2 deletions tests/test_trolley.py
@@ -66,7 +66,10 @@ def test_trolley_get_dataset(a_trolley, some_mint_studies):
assert datasets[0].SOPInstanceUID == "bimini"


def test_trolley_get_dataset_async(a_trolley, some_mint_studies):
def test_trolley_get_dataset_async(a_mint, a_wado, some_mint_studies):
a_wado.use_async = True
a_trolley = Trolley(searcher=a_mint, downloader=a_wado)

a_trolley.downloader.datasets_async = Mock(
return_value=iter(
[
@@ -80,7 +83,7 @@ def test_trolley_get_dataset_async(a_trolley, some_mint_studies):
)
)

datasets = list(a_trolley.fetch_all_datasets_async(some_mint_studies))
datasets = list(a_trolley.fetch_all_datasets(some_mint_studies))
assert len(datasets) == 3
assert datasets[0].SOPInstanceUID == "bimini"

