Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] hold ports until training #5890

Merged
merged 9 commits into from
Jun 20, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 41 additions & 67 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

It is based on dask-lightgbm, which was based on dask-xgboost.
"""
import operator
import socket
import time
from collections import defaultdict
from copy import deepcopy
from enum import Enum, auto
Expand Down Expand Up @@ -38,18 +40,20 @@
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]


class _HostWorkers:
class _RemoteSocket:
def acquire(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.bind(('', 0))
return self.socket.getsockname()[1]

def __init__(self, default: str, all_workers: List[str]):
self.default = default
self.all_workers = all_workers
def release(self):
self.socket.close()

def __eq__(self, other: object) -> bool:
return (
isinstance(other, type(self))
and self.default == other.default
and self.all_workers == other.all_workers
)

def _acquire_port() -> Tuple[_RemoteSocket, int]:
    """Reserve one open port by binding a socket and keeping it alive.

    Returns
    -------
    remote_socket : _RemoteSocket
        Holder that keeps the bound socket open so the port stays reserved
        until it is explicitly released.
    port : int
        The port number the socket was bound to.
    """
    remote_socket = _RemoteSocket()
    bound_port = remote_socket.acquire()
    return remote_socket, bound_port


class _DatasetNames(Enum):
Expand Down Expand Up @@ -83,49 +87,9 @@ def _get_dask_client(client: Optional[Client]) -> Client:
return client


def _find_n_open_ports(n: int) -> List[int]:
"""Find n random open ports on localhost.

Returns
-------
ports : list of int
n random open ports on localhost.
"""
sockets = []
for _ in range(n):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('', 0))
sockets.append(s)
ports = []
for s in sockets:
ports.append(s.getsockname()[1])
s.close()
return ports


def _group_workers_by_host(worker_addresses: Iterable[str]) -> Dict[str, _HostWorkers]:
    """Group all worker addresses by hostname.

    Raises
    ------
    ValueError
        If a worker address has no parseable hostname (e.g. lacks a scheme).

    Returns
    -------
    host_to_workers : dict
        mapping from hostname to all its workers.
    """
    grouped: Dict[str, _HostWorkers] = {}
    for worker_address in worker_addresses:
        host = urlparse(worker_address).hostname
        if not host:
            raise ValueError(f"Could not parse host name from worker address '{worker_address}'")
        existing = grouped.get(host)
        if existing is None:
            # First worker seen on this host becomes the host's default.
            grouped[host] = _HostWorkers(default=worker_address, all_workers=[worker_address])
        else:
            existing.all_workers.append(worker_address)
    return grouped


def _assign_open_ports_to_workers(
client: Client,
host_to_workers: Dict[str, _HostWorkers]
workers: List[str],
) -> Dict[str, int]:
"""Assign an open port to each worker.

Expand All @@ -134,22 +98,27 @@ def _assign_open_ports_to_workers(
worker_to_port: dict
mapping from worker address to an open port.
"""
host_ports_futures = {}
for hostname, workers in host_to_workers.items():
n_workers_in_host = len(workers.all_workers)
host_ports_futures[hostname] = client.submit(
_find_n_open_ports,
n=n_workers_in_host,
workers=[workers.default],
pure=False,
# Acquire port in worker
worker_to_future = {}
for worker in workers:
worker_to_future[worker] = client.submit(
_acquire_port,
workers=[worker],
allow_other_workers=False,
pure=False,
)
found_ports = client.gather(host_ports_futures)
worker_to_port = {}
for hostname, workers in host_to_workers.items():
for worker, port in zip(workers.all_workers, found_ports[hostname]):
worker_to_port[worker] = port
return worker_to_port

# schedule futures to retrieve each element of the tuple
worker_to_socket_future = {}
worker_to_port_future = {}
for worker, socket_future in worker_to_future.items():
worker_to_socket_future[worker] = client.submit(operator.itemgetter(0), socket_future)
worker_to_port_future[worker] = client.submit(operator.itemgetter(1), socket_future)
jameslamb marked this conversation as resolved.
Show resolved Hide resolved

# retrieve ports
worker_to_port = client.gather(worker_to_port_future)
jmoralez marked this conversation as resolved.
Show resolved Hide resolved

return worker_to_socket_future, worker_to_port


def _concat(seq: List[_DaskPart]) -> _DaskPart:
Expand Down Expand Up @@ -190,6 +159,7 @@ def _train_part(
num_machines: int,
return_model: bool,
time_out: int,
remote_socket: _RemoteSocket,
**kwargs: Any
) -> Optional[LGBMModel]:
network_params = {
Expand Down Expand Up @@ -319,6 +289,9 @@ def _train_part(
if eval_class_weight:
kwargs['eval_class_weight'] = [eval_class_weight[i] for i in eval_component_idx]

if remote_socket is not None:
remote_socket.release()
jmoralez marked this conversation as resolved.
Show resolved Hide resolved
time.sleep(0.1)
jmoralez marked this conversation as resolved.
Show resolved Hide resolved
model = model_factory(**params)
try:
if is_ranker:
Expand Down Expand Up @@ -777,6 +750,7 @@ def _train(
machines = params.pop("machines")

# figure out network params
worker_to_socket_future = {}
worker_addresses = worker_map.keys()
if machines is not None:
_log_info("Using passed-in 'machines' parameter")
Expand All @@ -802,8 +776,7 @@ def _train(
}
else:
_log_info("Finding random open ports for workers")
host_to_workers = _group_workers_by_host(worker_map.keys())
worker_address_to_port = _assign_open_ports_to_workers(client, host_to_workers)
worker_to_socket_future, worker_address_to_port = _assign_open_ports_to_workers(client, list(worker_map.keys()))

machines = ','.join([
f'{urlparse(worker_address).hostname}:{port}'
Expand Down Expand Up @@ -831,6 +804,7 @@ def _train(
local_listen_port=worker_address_to_port[worker],
num_machines=num_machines,
time_out=params.get('time_out', 120),
remote_socket=worker_to_socket_future.get(worker, None),
return_model=(worker == master_worker),
workers=[worker],
allow_other_workers=False,
Expand Down
47 changes: 7 additions & 40 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect
import random
import socket
import time
from itertools import groupby
from os import getenv
from platform import machine
Expand Down Expand Up @@ -519,50 +520,13 @@ def test_classifier_custom_objective(output, task, cluster):
assert_eq(p1_proba, p1_proba_local)


def test_group_workers_by_host():
    # Two hosts, each with two workers (ports 0 and 1).
    hosts = ['0.0.0.0', '0.0.0.1']
    workers = []
    for port in range(2):
        for host in hosts:
            workers.append(f'tcp://{host}:{port}')
    expected = {}
    for host in hosts:
        expected[host] = lgb.dask._HostWorkers(
            default=f'tcp://{host}:0',
            all_workers=[f'tcp://{host}:0', f'tcp://{host}:1'],
        )
    # Workers should be grouped per host, with the first one as default.
    assert lgb.dask._group_workers_by_host(workers) == expected


def test_group_workers_by_host_unparseable_host_names():
    # Addresses without a scheme prefix (like 'tcp://') yield no hostname
    # from urlparse and must be rejected with a descriptive error.
    expected_msg = "Could not parse host name from worker address '0.0.0.1:80'"
    with pytest.raises(ValueError, match=expected_msg):
        lgb.dask._group_workers_by_host(['0.0.0.1:80', '0.0.0.2:80'])


def test_machines_to_worker_map_unparseable_host_names():
    # Scheme-less worker addresses cannot be matched to 'machines' entries
    # and should raise with a clear message naming the bad address.
    worker_addresses = {'0.0.0.1:80': {}, '0.0.0.2:80': {}}
    expected_msg = "Could not parse host name from worker address '0.0.0.1:80'"
    with pytest.raises(ValueError, match=expected_msg):
        lgb.dask._machines_to_worker_map(
            machines="0.0.0.1:80,0.0.0.2:80",
            worker_addresses=worker_addresses.keys(),
        )


def test_assign_open_ports_to_workers(cluster):
    with Client(cluster) as client:
        worker_addresses = client.scheduler_info()['workers'].keys()
        expected_n_ports = len(worker_addresses)
        host_to_workers = lgb.dask._group_workers_by_host(worker_addresses)
        # Repeat several times to reduce the chance of passing by luck.
        for _ in range(25):
            assignment = lgb.dask._assign_open_ports_to_workers(client, host_to_workers)
            ports = list(assignment.values())
            # every worker received a port
            assert len(ports) == expected_n_ports
            # ports on the same host (LocalCluster) must not collide
            assert len(set(ports)) == len(ports)
            # each assigned port must still be bindable, i.e. actually open
            for port in ports:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                    probe.bind(('', port))


def test_training_does_not_fail_on_port_conflicts(cluster):
with Client(cluster) as client:
_, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array')
Expand Down Expand Up @@ -1588,15 +1552,18 @@ def test_network_params_not_required_but_respected_if_given(task, listen_port, c
assert 'machines' not in params

# model 2 - machines given
workers = list(client.scheduler_info()['workers'])
workers_hostname = _get_workers_hostname(cluster)
n_workers = len(client.scheduler_info()['workers'])
open_ports = lgb.dask._find_n_open_ports(n_workers)
remote_sockets, open_ports = lgb.dask._assign_open_ports_to_workers(client, workers)
for s in remote_sockets.values():
s.release()
time.sleep(0.1)
dask_model2 = dask_model_factory(
n_estimators=5,
num_leaves=5,
machines=",".join([
f"{workers_hostname}:{port}"
for port in open_ports
for port in open_ports.values()
]),
)

Expand Down