Skip to content

Commit

Permalink
Transport closing logic in tasks.py rather than transports.py?
Browse files Browse the repository at this point in the history
Otherwise we _always_ close the transport after requesting it??
  • Loading branch information
GeigerJ2 committed Oct 29, 2024
1 parent 9468f1e commit b0d077e
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 33 deletions.
45 changes: 25 additions & 20 deletions src/aiida/engine/processes/calcjobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,24 +54,6 @@
class PreSubmitException(Exception): # noqa: N818
"""Raise in the `do_upload` coroutine when an exception is raised in `CalcJob.presubmit`."""

async def get_transport(authinfo, transport_queue, cancellable):
transport_requests = transport_queue._transport_requests
last_transport_request = transport_requests.get(authinfo.pk, None)

# ? Refactor this into `obtain_transport` function
# ? Returns last transport if open, and awaits close callback handle, otherwise request new transport
if last_transport_request is None or transport_queue._last_request_special:
# This is the previous behavior
with transport_queue.request_transport(authinfo) as request:
transport = await cancellable.with_interrupt(request)
else:
transport = authinfo.get_transport()
if not transport.is_open:
with transport_queue.request_transport(authinfo) as request:
transport = await cancellable.with_interrupt(request)
else:
transport_queue._last_request_special = True

async def task_upload_job(process: 'CalcJob', transport_queue: TransportQueue, cancellable: InterruptableFuture):
"""Transport task that will attempt to upload the files of a job calculation to the remote.
Expand Down Expand Up @@ -158,11 +140,34 @@ async def task_submit_job(node: CalcJobNode, transport_queue: TransportQueue, ca
max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)

authinfo = node.get_authinfo()
authinfo_pk = authinfo.pk

transport_request = transport_queue._transport_requests.get(authinfo.pk, None)
open_transport = transport_queue._open_transports.get(authinfo.pk, None)

if open_transport is not None: # and not transport_queue._last_request_special:
transport = open_transport
transport_queue._last_request_special = True
elif transport_request is None: # or transport_queue._last_request_special:
# This is the previous behavior
with transport_queue.request_transport(authinfo) as request:
transport = await cancellable.with_interrupt(request)
else:
pass

async def do_submit():
transport_request = transport_queue._transport_requests.get(authinfo.pk, None)
open_transport = transport_queue._open_transports.get(authinfo.pk, None)

transport = get_transport(authinfo=authinfo, transport_queue=transport_queue, cancellable=cancellable)
print('a')
if open_transport is not None: # and not transport_queue._last_request_special:
transport = open_transport
transport_queue._last_request_special = True
elif transport_request is None: # or transport_queue._last_request_special:
# This is the previous behavior
with transport_queue.request_transport(authinfo) as request:
transport = await cancellable.with_interrupt(request)
else:
pass

return execmanager.submit_calculation(node, transport)

Expand Down
32 changes: 20 additions & 12 deletions src/aiida/engine/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
self._last_close_time = None
self._last_request_special: bool = False
self._close_callback_handle = None
self._open_transports: Dict[Hashable, Transport] = {}
# self._last_transport_request: Dict[Hashable, str] = {}

@property
Expand Down Expand Up @@ -102,6 +103,7 @@ def do_open():
try:
transport.open()
self._last_open_time = timezone.localtime(timezone.now())
self._open_transports[authinfo.pk] = transport
except Exception as exception:
_LOGGER.error('exception occurred while trying to open transport:\n %s', exception)
transport_request.future.set_exception(exception)
Expand Down Expand Up @@ -167,22 +169,28 @@ def do_open():
if transport_request.count == 0:
if transport_request.future.done():

def do_close():
"""Close the transport if conditions are met."""
transport_request.future.result().close()
self._last_close_time = timezone.localtime(timezone.now())
# ? Why is all this logic in the `request_transport` method?
# ? Shouldn't the logic to close a transport be outside, such that the transport is being closed
# ? once it was actually used???
pass
# def do_close():
# """Close the transport if conditions are met."""
# transport_request.future.result().close()
# self._last_close_time = timezone.localtime(timezone.now())

close_timedelta = (timezone.localtime(timezone.now()) - self._last_open_time).total_seconds()
# close_timedelta = (timezone.localtime(timezone.now()) - self._last_open_time).total_seconds()

if close_timedelta > safe_open_interval:
# if close_timedelta < safe_open_interval:

# Also here logic when transport should be closed immediately, or when via call_later?
close_callback_handle = self._loop.call_soon(do_close, context=contextvars.Context())
self._last_close_time = timezone.localtime(timezone.now())
self._transport_requests.pop(authinfo.pk, None)
else:
close_callback_handle = self._loop.call_later(safe_open_interval, do_close, context=contextvars.Context())
self._transport_requests.pop(authinfo.pk, None)
# self._last_close_time = timezone.localtime(timezone.now())
# self._transport_requests.pop(authinfo.pk, None)
# close_callback_handle = self._loop.call_later(safe_open_interval, do_close, context=contextvars.Context())
# if close_timedelta > safe_open_interval:
# close_callback_handle = self._loop.call_soon(do_close, context=contextvars.Context())
# self._last_close_time = timezone.localtime(timezone.now())
# self._transport_requests.pop(authinfo.pk, None)
# self._transport_requests.pop(authinfo.pk, None)

# transport_request.transport_closer = close_callback_handle

Expand Down
1 change: 0 additions & 1 deletion src/aiida/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ async def exponential_backoff_retry(

result: Any = None
coro = ensure_coroutine(fct)
print('a')
interval = initial_interval

for iteration in range(max_attempts):
Expand Down

0 comments on commit b0d077e

Please sign in to comment.