Skip to content

Commit

Permalink
Move async dependencies to core (#812)
Browse files Browse the repository at this point in the history
* Move exceptions to core library

* Move get_address to core library

* Move nvtx_annotate to core library

* Add missing copyright headers

* Add missing ucp._libs.utils with nvtx_annotate definition

* Remove extra spaces from copyright

Co-authored-by: Mads R. B. Kristensen <[email protected]>
  • Loading branch information
pentschev and madsbk authored Dec 1, 2021
1 parent 1cd6de1 commit 9221fba
Show file tree
Hide file tree
Showing 28 changed files with 196 additions and 166 deletions.
3 changes: 2 additions & 1 deletion ucp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from ._version import get_versions as _get_versions # noqa
from .core import * # noqa
from .core import get_ucx_version # noqa
from .utils import get_address, get_ucxpy_logger # noqa
from .utils import get_ucxpy_logger # noqa
from ._libs.ucx_api import get_address # noqa

if "UCX_SOCKADDR_TLS_PRIORITY" not in os.environ and get_ucx_version() < (1, 11, 0):
logger.debug(
Expand Down
4 changes: 4 additions & 0 deletions ucp/_libs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
# See file LICENSE for terms.

from .utils import nvtx_annotate # noqa
53 changes: 53 additions & 0 deletions ucp/_libs/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
# See file LICENSE for terms.

import contextlib
import logging

logger = logging.getLogger("ucx")


@contextlib.contextmanager
def log_errors(reraise_exception=False):
try:
yield
except BaseException as e:
logger.exception(e)
if reraise_exception:
raise


class UCXBaseException(Exception):
pass


class UCXError(UCXBaseException):
pass


class UCXConfigError(UCXError):
pass


class UCXWarning(UserWarning):
pass


class UCXCloseError(UCXBaseException):
pass


class UCXCanceled(UCXBaseException):
pass


class UCXConnectionReset(UCXBaseException):
pass


class UCXMsgTruncated(UCXBaseException):
pass


class UCXNotConnected(UCXBaseException):
pass
6 changes: 3 additions & 3 deletions ucp/_libs/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import pytest

import ucp.exceptions
from ucp._libs import ucx_api
from ucp._libs.arr import Array
from ucp._libs.exceptions import UCXConfigError


def test_get_config():
Expand Down Expand Up @@ -46,13 +46,13 @@ def test_init_options():
)
def test_init_unknown_option():
options = {"UNKNOWN_OPTION": "3M"}
with pytest.raises(ucp.exceptions.UCXConfigError):
with pytest.raises(UCXConfigError):
ucx_api.UCXContext(options)


def test_init_invalid_option():
options = {"SEG_SIZE": "invalid-size"}
with pytest.raises(ucp.exceptions.UCXConfigError):
with pytest.raises(UCXConfigError):
ucx_api.UCXContext(options)


Expand Down
5 changes: 4 additions & 1 deletion ucp/_libs/tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ def _client(port, endpoint_error_handling, server_close_callback):
ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
worker = ucx_api.UCXWorker(ctx)
ep = ucx_api.UCXEndpoint.create(
worker, "localhost", port, endpoint_error_handling=endpoint_error_handling,
worker,
ucx_api.get_address(),
port,
endpoint_error_handling=endpoint_error_handling,
)
if server_close_callback is True:
ep.close()
Expand Down
5 changes: 4 additions & 1 deletion ucp/_libs/tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ def _echo_client(msg_size, port, endpoint_error_handling):
ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
worker = ucx_api.UCXWorker(ctx)
ep = ucx_api.UCXEndpoint.create(
worker, "localhost", port, endpoint_error_handling=endpoint_error_handling,
worker,
ucx_api.get_address(),
port,
endpoint_error_handling=endpoint_error_handling,
)
send_msg = bytes(os.urandom(msg_size))
recv_msg = bytearray(msg_size)
Expand Down
5 changes: 4 additions & 1 deletion ucp/_libs/tests/test_server_client_am.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,10 @@ def _echo_client(msg_size, datatype, port, endpoint_error_handling):
worker.register_am_allocator(data["allocator"], data["memory_type"])

ep = ucx_api.UCXEndpoint.create(
worker, "localhost", port, endpoint_error_handling=endpoint_error_handling,
worker,
ucx_api.get_address(),
port,
endpoint_error_handling=endpoint_error_handling,
)

# The wireup message is sent to ensure endpoints are connected, otherwise
Expand Down
3 changes: 1 addition & 2 deletions ucp/_libs/transfer_am.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ from libc.stdint cimport uintptr_t
from libc.stdlib cimport free

from .arr cimport Array
from .exceptions import UCXCanceled, UCXError, log_errors
from .ucx_api_dep cimport *

from ..exceptions import UCXCanceled, UCXError, log_errors

logger = logging.getLogger("ucx")


Expand Down
3 changes: 1 addition & 2 deletions ucp/_libs/transfer_common.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@

from libc.stdint cimport uintptr_t

from .exceptions import UCXCanceled, UCXError, log_errors
from .ucx_api_dep cimport *

from ..exceptions import UCXCanceled, UCXError, log_errors


# This callback function is currently needed by stream_send_nb and
# tag_send_nb transfer functions, as well as UCXEndpoint and UCXWorker
Expand Down
5 changes: 2 additions & 3 deletions ucp/_libs/transfer_stream.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
from libc.stdint cimport uintptr_t

from .arr cimport Array
from .ucx_api_dep cimport *

from ..exceptions import (
from .exceptions import (
UCXCanceled,
UCXError,
UCXMsgTruncated,
UCXNotConnected,
log_errors,
)
from .ucx_api_dep cimport *


def stream_send_nb(
Expand Down
3 changes: 1 addition & 2 deletions ucp/_libs/transfer_tag.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
from libc.stdint cimport uintptr_t

from .arr cimport Array
from .exceptions import UCXCanceled, UCXError, UCXMsgTruncated, log_errors
from .ucx_api_dep cimport *

from ..exceptions import UCXCanceled, UCXError, UCXMsgTruncated, log_errors


def tag_send_nb(
UCXEndpoint ep,
Expand Down
3 changes: 3 additions & 0 deletions ucp/_libs/ucx_api.pyx
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
# See file LICENSE for terms.

include "packed_remote_key.pyx"
include "transfer_am.pyx"
include "transfer_common.pyx"
Expand Down
3 changes: 1 addition & 2 deletions ucp/_libs/ucx_endpoint.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@ import warnings
from libc.stdint cimport uintptr_t
from libc.stdio cimport FILE

from .exceptions import UCXCanceled, UCXConnectionReset, UCXError
from .ucx_api_dep cimport *

from ..exceptions import UCXCanceled, UCXConnectionReset, UCXError

logger = logging.getLogger("ucx")


Expand Down
3 changes: 1 addition & 2 deletions ucp/_libs/ucx_listener.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

from libc.stdint cimport uint16_t, uintptr_t

from .exceptions import log_errors
from .ucx_api_dep cimport *

from ..exceptions import log_errors


cdef void _listener_callback(ucp_conn_request_h conn_request, void *args) with gil:
"""Callback function used by UCXListener"""
Expand Down
3 changes: 2 additions & 1 deletion ucp/_libs/ucx_memory_handle.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2021 UT-Battelle, LLC. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, UT-Battelle, LLC. All rights reserved.
# See file LICENSE for terms.

# cython: language_level=3
Expand Down
3 changes: 1 addition & 2 deletions ucp/_libs/ucx_request.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from cpython.ref cimport Py_DECREF, Py_INCREF, PyObject
from libc.stdint cimport uintptr_t

from .exceptions import UCXError, UCXMsgTruncated
from .ucx_api_dep cimport *

from ..exceptions import UCXError, UCXMsgTruncated


# Counter used as UCXRequest UIDs
cdef unsigned int _ucx_py_request_counter = 0
Expand Down
3 changes: 2 additions & 1 deletion ucp/_libs/ucx_rkey.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2021 UT-Battelle, LLC. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, UT-Battelle, LLC. All rights reserved.
# See file LICENSE for terms.

# cython: language_level=3
Expand Down
3 changes: 1 addition & 2 deletions ucp/_libs/ucx_rma.pyx
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from io import RawIOBase

from .arr cimport Array
from .exceptions import UCXError
from .ucx_api_dep cimport *

from ..exceptions import UCXError


class RemoteMemory:
"""This class wraps all of the rkey meta data and remote memory locations to do
Expand Down
5 changes: 2 additions & 3 deletions ucp/_libs/ucx_worker.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ from libc.stdint cimport uint16_t, uintptr_t
from libc.stdio cimport FILE
from libc.string cimport memset

from .exceptions import UCXError
from .ucx_api_dep cimport *

from ..exceptions import UCXError
from ..utils import nvtx_annotate
from .utils import nvtx_annotate

logger = logging.getLogger("ucx")

Expand Down
3 changes: 1 addition & 2 deletions ucp/_libs/ucx_worker_cb.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@ from cython cimport boundscheck, initializedcheck, nonecheck, wraparound
from libc.stdint cimport uintptr_t
from libc.string cimport memcpy

from .exceptions import UCXCanceled, UCXError, log_errors
from .ucx_api_dep cimport *

from ..exceptions import UCXCanceled, UCXError, log_errors

logger = logging.getLogger("ucx")


Expand Down
3 changes: 3 additions & 0 deletions ucp/_libs/ucxio.pyx
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) 2021, UT-Battelle, LLC. All rights reserved.
# See file LICENSE for terms.

from io import RawIOBase

from .arr cimport Array
Expand Down
12 changes: 12 additions & 0 deletions ucp/_libs/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# See file LICENSE for terms.

try:
from nvtx import annotate as nvtx_annotate
except ImportError:
# If nvtx module is not installed, `annotate` yields only.
from contextlib import contextmanager

@contextmanager
def nvtx_annotate(message=None, color=None, domain=None):
yield
70 changes: 68 additions & 2 deletions ucp/_libs/utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@

# cython: language_level=3

import fcntl
import glob
import os
import socket
import struct

from cpython.buffer cimport PyBUF_FORMAT, PyBUF_ND, PyBUF_WRITABLE
from libc.stdio cimport (
FILE,
Expand All @@ -18,10 +24,9 @@ from libc.stdio cimport (
)
from libc.stdlib cimport free

from .exceptions import UCXConfigError, UCXError
from .ucx_api_dep cimport *

from ..exceptions import UCXConfigError, UCXError


cdef FILE * create_text_fd():
cdef FILE *text_fd = tmpfile()
Expand Down Expand Up @@ -182,3 +187,64 @@ def is_am_supported():
return get_ucx_version() >= (1, 11, 0)
ELSE:
return False


def get_address(ifname=None):
"""
Get the address associated with a network interface.
Parameters
----------
ifname : str
The network interface name to find the address for.
If None, it uses the value of environment variable `UCXPY_IFNAME`
and if `UCXPY_IFNAME` is not set it defaults to "ib0"
An OSError is raised for invalid interfaces.
Returns
-------
address : str
The inet addr associated with an interface.
Examples
--------
>>> get_address()
'10.33.225.160'
>>> get_address(ifname='lo')
'127.0.0.1'
"""

def _get_address(ifname):
ifname = ifname.encode()
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
return socket.inet_ntoa(
fcntl.ioctl(
s.fileno(), 0x8915, struct.pack("256s", ifname[:15]) # SIOCGIFADDR
)[20:24]
)

def _try_interfaces():
prefix_priority = ["ib", "eth", "en"]
iftypes = {p: [] for p in prefix_priority}
for i in glob.glob("/sys/class/net/*"):
name = i.split("/")[-1]
for p in prefix_priority:
if name.startswith(p):
iftypes[p].append(name)
for p in prefix_priority:
iftype = iftypes[p]
iftype.sort()
for i in iftype:
try:
return _get_address(i)
except OSError:
pass

if ifname is None:
ifname = os.environ.get("UCXPY_IFNAME")

if ifname is not None:
return _get_address(ifname)
else:
return _try_interfaces()
3 changes: 3 additions & 0 deletions ucp/_libs/utils_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
# See file LICENSE for terms.

import multiprocessing as mp

from ucp._libs import ucx_api
Expand Down
Loading

0 comments on commit 9221fba

Please sign in to comment.