diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 679293a52d..d8e085c8e6 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -73,7 +73,8 @@ ExponentialReconnectionPolicy, HostDistance, RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan, NoSpeculativeExecutionPolicy, DefaultLoadBalancingPolicy, - NeverRetryPolicy) + NeverRetryPolicy, ShardConnectionBackoffPolicy, NoDelayShardConnectionBackoffPolicy, + ShardConnectionScheduler) from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler, HostConnectionPool, HostConnection, NoConnectionsAvailable) @@ -757,6 +758,11 @@ def auth_provider(self, value): self._auth_provider = value + _shard_connection_backoff_policy: ShardConnectionBackoffPolicy + @property + def shard_connection_backoff_policy(self) -> ShardConnectionBackoffPolicy: + return self._shard_connection_backoff_policy + _load_balancing_policy = None @property def load_balancing_policy(self): @@ -1219,7 +1225,8 @@ def __init__(self, shard_aware_options=None, metadata_request_timeout=None, column_encryption_policy=None, - application_info:Optional[ApplicationInfoBase]=None + application_info: Optional[ApplicationInfoBase] = None, + shard_connection_backoff_policy: Optional[ShardConnectionBackoffPolicy] = None, ): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as @@ -1325,6 +1332,13 @@ def __init__(self, else: self._load_balancing_policy = default_lbp_factory() # set internal attribute to avoid committing to legacy config mode + if shard_connection_backoff_policy is not None: + if not isinstance(shard_connection_backoff_policy, ShardConnectionBackoffPolicy): + raise TypeError("shard_connection_backoff_policy should be an instance of class derived from ShardConnectionBackoffPolicy") + self._shard_connection_backoff_policy = shard_connection_backoff_policy + else: + self._shard_connection_backoff_policy = NoDelayShardConnectionBackoffPolicy() + if reconnection_policy is not None: if isinstance(reconnection_policy, type): raise TypeError("reconnection_policy should not be a class, it should be an instance of that class") @@ -2716,6 +2730,7 @@ def default_serial_consistency_level(self, cl): _metrics = None _request_init_callbacks = None _graph_paging_available = False + shard_connection_backoff_scheduler: ShardConnectionScheduler def __init__(self, cluster, hosts, keyspace=None): self.cluster = cluster @@ -2730,6 +2745,7 @@ def __init__(self, cluster, hosts, keyspace=None): self._protocol_version = self.cluster.protocol_version self.encoder = Encoder() + self.shard_connection_backoff_scheduler = cluster.shard_connection_backoff_policy.new_connection_scheduler(self.cluster.scheduler) # create connection pools in parallel self._initial_connect_futures = set() @@ -3340,6 +3356,7 @@ def shutdown(self): else: self.is_shutdown = True + self.shard_connection_backoff_scheduler.shutdown() # PYTHON-673. If shutdown was called shortly after session init, avoid # a race by cancelling any initial connection attempts haven't started, # then blocking on any that have. 
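For context, a minimal usage sketch of the new `shard_connection_backoff_policy` argument introduced above (not part of the diff itself; it mirrors how the integration tests below exercise the API, and the contact point is illustrative):

```python
from cassandra.cluster import Cluster
from cassandra.policies import (ConstantReconnectionPolicy,
                                LimitedConcurrencyShardConnectionBackoffPolicy)

# Throttle shard connection creation: at most one in-flight attempt per host,
# with a constant 0.1 s delay between attempts to the same host.
cluster = Cluster(
    contact_points=["127.0.0.1"],  # illustrative
    shard_connection_backoff_policy=LimitedConcurrencyShardConnectionBackoffPolicy(
        backoff_policy=ConstantReconnectionPolicy(0.1),
        max_concurrent=1,
    ),
)
session = cluster.connect()
# When the argument is omitted, NoDelayShardConnectionBackoffPolicy is used,
# which preserves the previous behavior of opening shard connections immediately.
```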
@@ -4456,7 +4473,12 @@ def shutdown(self):
         self.join()
 
     def schedule(self, delay, fn, *args, **kwargs):
-        self._insert_task(delay, (fn, args, tuple(kwargs.items())))
+        if self.is_shutdown:
+            return
+        if delay:
+            self._insert_task(delay, (fn, args, tuple(kwargs.items())))
+        else:
+            self._executor.submit(fn, *args, **kwargs)
 
     def schedule_unique(self, delay, fn, *args, **kwargs):
         task = (fn, args, tuple(kwargs.items()))
diff --git a/cassandra/policies.py b/cassandra/policies.py
index cb83238e87..ba65717d11 100644
--- a/cassandra/policies.py
+++ b/cassandra/policies.py
@@ -11,21 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import random
+from __future__ import annotations
 
+import random
 from collections import namedtuple
-from functools import lru_cache
+from functools import partial
 from itertools import islice, cycle, groupby, repeat
 import logging
 from random import randint, shuffle
 from threading import Lock
 import socket
 import warnings
-
-log = logging.getLogger(__name__)
-
+from typing import Callable, TYPE_CHECKING, Iterator, List, Tuple
+from abc import ABC, abstractmethod
 from cassandra import WriteType as WT
 
+if TYPE_CHECKING:
+    from cassandra.cluster import _Scheduler
+
+log = logging.getLogger(__name__)
 
 # This is done this way because WriteType was originally
 # defined here and in order not to break the API.
@@ -864,6 +868,368 @@ def _add_jitter(self, value):
         return min(max(self.base_delay, delay), self.max_delay)
 
 
+class ShardConnectionScheduler(ABC):
+    """
+    A base class for shard connection backoff schedulers.
+    A ``ShardConnectionScheduler`` is created per ``Session`` instance and schedules per-shard
+    connections according to the ``ShardConnectionBackoffPolicy`` that instantiated it.
+    """
+
+    @abstractmethod
+    def schedule(
+            self,
+            host_id: str,
+            shard_id: int,
+            method: Callable[[], None],
+    ) -> bool:
+        """
+        Schedules a request to create a connection to the given host and shard according to the scheduling policy.
+
+        The `schedule` method is called whenever `HostConnection` needs to establish a connection to the specified
+        (host_id, shard_id) pair.
+
+        The responsibilities of `schedule` are as follows:
+        1. Deduplicate requests for the same (host_id, shard_id) pair, considering both queued and currently running requests.
+        2. Ensure that requests are executed at a pace consistent with the expected behavior of the
+           `ShardConnectionScheduler` implementation.
+
+        The `schedule` method must never execute `method` immediately; `method` should always be run in a separate thread.
+        Handling of failed `method` executions is managed by higher-level logic (`HostConnection`) and should not be
+        performed by the implementation of `schedule`.
+
+        Parameters:
+        ``host_id`` - an id of the host of the shard.
+        ``shard_id`` - an id of the shard.
+        ``method`` - a callable that creates a connection and stores it in the connection pool; it handles a closed
+        session or ``HostConnection`` and the case where a connection to the shard already exists.
+        Currently, it is ``HostConnection._open_connection_to_missing_shard``.
+
+        :return: `bool` indicating whether the request was scheduled.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def shutdown(self):
+        """
+        Shutdown the scheduler.
+
+        It should stop the scheduler from executing any creation requests.
+        Requests that are already scheduled should be canceled.
+        It is acceptable for currently running requests to complete.
+        """
+        raise NotImplementedError()
+
+
+class ShardConnectionBackoffPolicy(ABC):
+    """
+    Base class for shard connection backoff policies.
+    These policies allow the user to control the pace at which new connections are established.
+    """
+
+    @abstractmethod
+    def new_connection_scheduler(self, scheduler: _Scheduler) -> ShardConnectionScheduler:
+        """
+        Instantiate a connection scheduler that behaves according to the policy.
+
+        It is called on session initialization.
+        """
+        raise NotImplementedError()
+
+
+class NoDelayShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy):
+    """
+    A shard connection backoff policy with no delay between attempts.
+    Ensures at most one pending connection request per (host, shard) pair;
+    duplicate connection attempts for the same (host, shard) are silently dropped.
+    """
+
+    def new_connection_scheduler(self, scheduler: _Scheduler) -> ShardConnectionScheduler:
+        return _NoDelayShardConnectionBackoffScheduler(scheduler)
+
+
+class _NoDelayShardConnectionBackoffScheduler(ShardConnectionScheduler):
+    """
+    A scheduler for ``cassandra.policies.NoDelayShardConnectionBackoffPolicy``.
+
+    It schedules shard connection requests with no delay between attempts.
+    Ensures at most one pending connection request per (host, shard) pair;
+    duplicate connection attempts for the same (host, shard) are silently dropped.
+    """
+
+    scheduler: _Scheduler
+    already_scheduled: set[tuple[str, int]]
+    lock: Lock
+    is_shutdown: bool = False
+
+    def __init__(self, scheduler: _Scheduler):
+        self.scheduler = scheduler
+        self.already_scheduled = set()
+        self.lock = Lock()
+
+    def _execute(
+            self,
+            host_id: str,
+            shard_id: int,
+            method: Callable[[], None],
+    ) -> None:
+        if self.is_shutdown:
+            return
+        try:
+            method()
+        finally:
+            with self.lock:
+                self.already_scheduled.remove((host_id, shard_id))
+
+    def schedule(
+            self,
+            host_id: str,
+            shard_id: int,
+            method: Callable[[], None],
+    ) -> bool:
+        with self.lock:
+            if self.is_shutdown or (host_id, shard_id) in self.already_scheduled:
+                return False
+            self.already_scheduled.add((host_id, shard_id))
+
+        self.scheduler.schedule(0, self._execute, host_id, shard_id, method)
+        return True
+
+    def shutdown(self):
+        with self.lock:
+            self.is_shutdown = True
+
+
+class ShardConnectionBackoffSchedule(ABC):
+    @abstractmethod
+    def new_schedule(self) -> Iterator[float]:
+        """
+        This should return a finite or infinite iterable of delays (each as a
+        floating point number of seconds).
+        Note that if the iterable is finite, the schedule will be recreated right after it is exhausted.
+        """
+        raise NotImplementedError()
+
+
+class ConstantShardConnectionBackoffSchedule(ShardConnectionBackoffSchedule):
+    """
+    A :class:`.ShardConnectionBackoffSchedule` subclass which introduces a constant delay,
+    with optional jitter, between shard connection attempts.
+    """
+
+    def __init__(self, delay: float, jitter: float = 0.0):
+        """
+        `delay` should be a floating point number of seconds to wait in-between
+        each connection attempt.
+
+        `jitter` is a random jitter in seconds.
+        """
+        if delay < 0:
+            raise ValueError("delay must not be negative")
+        if jitter < 0:
+            raise ValueError("jitter must not be negative")
+
+        self.delay = delay
+        self.jitter = jitter
+
+    def new_schedule(self):
+        if self.jitter == 0:
+            return repeat(self.delay)
+
+        def iterator():
+            while True:
+                yield self.delay + random.uniform(0.0, self.jitter)
+        return iterator()
+
+
+class LimitedConcurrencyShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy):
+    """
+    A shard connection backoff policy that allows only `max_concurrent` concurrent connections per `host_id`.
+
+    For backoff calculation, it requires either a `cassandra.policies.ShardConnectionBackoffSchedule` or
+    a `cassandra.policies.ReconnectionPolicy`, as both expose the same API.
+
+    It spawns workers while there are pending requests; the maximum number of workers is `max_concurrent`
+    multiplied by the number of nodes in the cluster.
+    Each worker follows its own backoff schedule; when no requests remain for its `host_id`, the worker stops.
+
+    This policy also prevents multiple pending or scheduled connections for the same (host, shard) pair;
+    any duplicate attempts to schedule a connection are silently ignored.
+    """
+    backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy
+
+    max_concurrent: int
+    """
+    Max concurrent connection creation requests per scope.
+    """
+
+    def __init__(
+            self,
+            backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy,
+            max_concurrent: int = 1,
+    ):
+        if not isinstance(backoff_policy, (ShardConnectionBackoffSchedule, ReconnectionPolicy)):
+            raise ValueError("backoff_policy must be a ShardConnectionBackoffSchedule or ReconnectionPolicy")
+        if max_concurrent < 1:
+            raise ValueError("max_concurrent must be a positive integer")
+        self.backoff_policy = backoff_policy
+        self.max_concurrent = max_concurrent
+
+    def new_connection_scheduler(self, scheduler: _Scheduler) -> ShardConnectionScheduler:
+        return _LimitedConcurrencyShardConnectionScheduler(scheduler, self.backoff_policy, self.max_concurrent)
+
+
+class _ScopeBucket:
+    """
+    Holds state for one shard connection backoff scope; schedules and executes its connection creation requests.
+    """
+    scheduler: _Scheduler
+    backoff_policy: ShardConnectionBackoffSchedule
+    lock: Lock
+    is_shutdown: bool = False
+
+    max_concurrent: int
+    """
+    Max concurrent connection creation requests in the scope.
+    """
+
+    currently_pending: int
+    """
+    Number of currently pending connections.
+    """
+
+    items: List[Callable[[], None]]
+    """
+    List of scheduled connection creation requests.
+    """
+
+    def __init__(
+            self,
+            scheduler: _Scheduler,
+            backoff_policy: ShardConnectionBackoffSchedule,
+            max_concurrent: int,
+    ):
+        self.items = []
+        self.scheduler = scheduler
+        self.backoff_policy = backoff_policy
+        self.lock = Lock()
+        self.max_concurrent = max_concurrent
+        self.currently_pending = 0
+
+    def _get_delay(self, schedule: Iterator[float]) -> Tuple[Iterator[float], float]:
+        try:
+            return schedule, next(schedule)
+        except StopIteration:
+            # The schedule is exhausted: start a new one. The fresh schedule is
+            # passed along the call chain rather than stored on self, so no lock
+            # is needed around it.
+            schedule = self.backoff_policy.new_schedule()
+            return schedule, next(schedule)
+
+    def _run(self, schedule: Iterator[float]):
+        if self.is_shutdown:
+            return
+
+        with self.lock:
+            try:
+                request = self.items.pop(0)
+            except IndexError:
+                # No items left: stop this worker, keeping the pending counter consistent.
+                if self.currently_pending > 0:
+                    self.currently_pending -= 1
+                # The exhausted schedule dies with this worker, so new items get a fresh
+                # schedule; this is important for exponential policies.
+                return
+
+        try:
+            request()
+        finally:
+            schedule, delay = self._get_delay(schedule)
+            self.scheduler.schedule(delay, self._run, schedule)
+
+    def schedule_new_connection(self, cb: Callable[[], None]):
+        with self.lock:
+            if self.is_shutdown:
+                return
+            self.items.append(cb)
+            if self.currently_pending < self.max_concurrent:
+                self.currently_pending += 1
+                schedule = self.backoff_policy.new_schedule()
+                delay = next(schedule)
+                self.scheduler.schedule(delay, self._run, schedule)
+
+    def shutdown(self):
+        with self.lock:
+            self.is_shutdown = True
+
+
+class _LimitedConcurrencyShardConnectionScheduler(ShardConnectionScheduler):
+    """
+    A scheduler for ``cassandra.policies.LimitedConcurrencyShardConnectionBackoffPolicy``.
+
+    Limits concurrency for connection creation requests to ``max_concurrent`` per host_id.
+    """
+
+    already_scheduled: set[tuple[str, int]]
+    """
+    Set of (host_id, shard_id) pairs of scheduled or pending requests.
+    """
+
+    per_host_scope: dict[str, _ScopeBucket]
+    """
+    Scope storage; the key is host_id, the value is the instance holding that scope's state.
+    """
+
+    backoff_policy: ShardConnectionBackoffSchedule
+    scheduler: _Scheduler
+    lock: Lock
+    is_shutdown: bool = False
+
+    max_concurrent: int
+    """
+    Max concurrent connection creation requests per host_id.
+ """ + + def __init__( + self, + scheduler: _Scheduler, + backoff_policy: ShardConnectionBackoffSchedule, + max_concurrent: int, + ): + self.already_scheduled = set() + self.per_host_scope = {} + self.backoff_policy = backoff_policy + self.max_concurrent = max_concurrent + self.scheduler = scheduler + self.lock = Lock() + + def _execute(self, host_id: str, shard_id: int, method: Callable[[], None]): + if self.is_shutdown: + return + try: + method() + finally: + with self.lock: + self.already_scheduled.remove((host_id, shard_id)) + + def schedule(self, host_id: str, shard_id: int, method: Callable[[], None]) -> bool: + with self.lock: + if self.is_shutdown or (host_id, shard_id) in self.already_scheduled: + return False + self.already_scheduled.add((host_id, shard_id)) + + scope_info = self.per_host_scope.get(host_id) + if not scope_info: + scope_info = _ScopeBucket(self.scheduler, self.backoff_policy, self.max_concurrent) + self.per_host_scope[host_id] = scope_info + scope_info.schedule_new_connection(partial(self._execute, host_id, shard_id, method)) + return True + + def shutdown(self): + with self.lock: + self.is_shutdown = True + for scope in self.per_host_scope.values(): + scope.shutdown() + + class RetryPolicy(object): """ A policy that describes whether to retry, rethrow, or ignore coordinator diff --git a/cassandra/pool.py b/cassandra/pool.py index d1f6604abf..c5922c1bbf 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -16,7 +16,7 @@ Connection pooling and host management. """ from concurrent.futures import Future -from functools import total_ordering +from functools import total_ordering, partial import logging import socket import time @@ -402,7 +402,6 @@ def __init__(self, host, host_distance, session): # this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool. self._stream_available_condition = Condition(Lock()) self._is_replacing = False - self._connecting = set() self._connections = {} self._pending_connections = [] # A pool of additional connections which are not used but affect how Scylla @@ -418,7 +417,6 @@ def __init__(self, host, host_distance, session): # and are waiting until all requests time out or complete # so that we can dispose of them. self._trash = set() - self._shard_connections_futures = [] self.advanced_shardaware_block_until = 0 if host_distance == HostDistance.IGNORED: @@ -483,25 +481,25 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table self.host, routing_key ) - if conn.orphaned_threshold_reached and shard_id not in self._connecting: + if conn.orphaned_threshold_reached: # The connection has met its orphaned stream ID limit # and needs to be replaced. Start opening a connection # to the same shard and replace when it is opened. 
-            self._connecting.add(shard_id)
-            self._session.submit(self._open_connection_to_missing_shard, shard_id)
+            self._session.shard_connection_backoff_scheduler.schedule(
+                self.host.host_id, shard_id, partial(self._open_connection_to_missing_shard, shard_id))
             log.debug(
-                "Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
+                "Connection to shard_id=%i reached orphaned stream limit, scheduling replacement on host %s (%s/%i)",
                 shard_id,
                 self.host,
                 len(self._connections.keys()),
                 self.host.sharding_info.shards_count
             )
-        elif shard_id not in self._connecting:
+        else:
             # rate controlled optimistic attempt to connect to a missing shard
-            self._connecting.add(shard_id)
-            self._session.submit(self._open_connection_to_missing_shard, shard_id)
+            self._session.shard_connection_backoff_scheduler.schedule(
+                self.host.host_id, shard_id, partial(self._open_connection_to_missing_shard, shard_id))
             log.debug(
-                "Trying to connect to missing shard_id=%i on host %s (%s/%i)",
+                "Scheduling connection to missing shard_id=%i on host %s (%s/%i)",
                 shard_id,
                 self.host,
                 len(self._connections.keys()),
@@ -609,8 +607,8 @@ def _replace(self, connection):
             if connection.features.shard_id in self._connections.keys():
                 del self._connections[connection.features.shard_id]
             if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable:
-                self._connecting.add(connection.features.shard_id)
-                self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id)
+                self._session.shard_connection_backoff_scheduler.schedule(
+                    self.host.host_id, connection.features.shard_id, partial(self._open_connection_to_missing_shard, connection.features.shard_id))
             else:
                 connection = self._session.cluster.connection_factory(self.host.endpoint,
                                                                       on_orphaned_stream_released=self.on_orphaned_stream_released)
@@ -635,9 +633,6 @@ def shutdown(self):
             with self._stream_available_condition:
                 self._stream_available_condition.notify_all()
 
-        for future in self._shard_connections_futures:
-            future.cancel()
-
         connections_to_close = self._connections.copy()
         pending_connections_to_close = self._pending_connections.copy()
         self._connections.clear()
@@ -843,7 +838,6 @@ def _open_connection_to_missing_shard(self, shard_id):
                 self._excess_connections.add(conn)
             if close_connection:
                 conn.close()
-        self._connecting.discard(shard_id)
 
     def _open_connections_for_all_shards(self, skip_shard_id=None):
         """
@@ -856,10 +850,8 @@ def _open_connections_for_all_shards(self, skip_shard_id=None):
         for shard_id in range(self.host.sharding_info.shards_count):
            if skip_shard_id is not None and skip_shard_id == shard_id:
                continue
-            future = self._session.submit(self._open_connection_to_missing_shard, shard_id)
-            if isinstance(future, Future):
-                self._connecting.add(shard_id)
-                self._shard_connections_futures.append(future)
+            self._session.shard_connection_backoff_scheduler.schedule(
+                self.host.host_id, shard_id, partial(self._open_connection_to_missing_shard, shard_id))
 
         trash_conns = None
         with self._lock:
diff --git a/tests/integration/long/test_policies.py b/tests/integration/long/test_policies.py
index 33f35ced0d..11e5d1d758 100644
--- a/tests/integration/long/test_policies.py
+++ b/tests/integration/long/test_policies.py
@@ -11,16 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import inspect
+import os
+import time
 import unittest
+from typing import Optional
+from unittest.mock import Mock
 
 from cassandra import ConsistencyLevel, Unavailable
-from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT
+from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT, Session
+from cassandra.policies import LimitedConcurrencyShardConnectionBackoffPolicy, ConstantReconnectionPolicy, \
+    ShardConnectionBackoffPolicy, NoDelayShardConnectionBackoffPolicy, _ScopeBucket, \
+    _NoDelayShardConnectionBackoffScheduler
+from cassandra.shard_info import _ShardingInfo
 
 from tests.integration import use_cluster, get_cluster, get_node, TestCluster
 
 
 def setup_module():
+    os.environ['SCYLLA_EXT_OPTS'] = "--smp 4"
     use_cluster('test_cluster', [4])
 
@@ -65,3 +74,102 @@ def test_should_rethrow_on_unvailable_with_default_policy_if_cas(self):
         self.assertEqual(exception.consistency, ConsistencyLevel.SERIAL)
         self.assertEqual(exception.required_replicas, 2)
         self.assertEqual(exception.alive_replicas, 1)
+
+
+class ShardBackoffPolicyTests(unittest.TestCase):
+    @classmethod
+    def tearDownClass(cls):
+        cluster = get_cluster()
+        cluster.start(wait_for_binary_proto=True, wait_other_notice=True)  # make sure other nodes are restarted
+
+    def test_limited_concurrency_1_connection_per_host(self):
+        self._test_backoff(
+            LimitedConcurrencyShardConnectionBackoffPolicy(
+                backoff_policy=ConstantReconnectionPolicy(0.1),
+                max_concurrent=1,
+            )
+        )
+
+    def test_limited_concurrency_2_connection_per_host(self):
+        self._test_backoff(
+            LimitedConcurrencyShardConnectionBackoffPolicy(
+                backoff_policy=ConstantReconnectionPolicy(0.1),
+                max_concurrent=2,
+            )
+        )
+
+    def test_no_delay(self):
+        self._test_backoff(NoDelayShardConnectionBackoffPolicy())
+
+    def _test_backoff(self, shard_connection_backoff_policy: ShardConnectionBackoffPolicy):
+        backoff_policy = None
+        if isinstance(shard_connection_backoff_policy, LimitedConcurrencyShardConnectionBackoffPolicy):
+            backoff_policy = shard_connection_backoff_policy.backoff_policy
+
+        cluster = TestCluster(
+            shard_connection_backoff_policy=shard_connection_backoff_policy,
+            reconnection_policy=ConstantReconnectionPolicy(0),
+        )
+
+        # Collect scheduled calls and execute them right away
+        scheduler_calls = []
+        original_schedule = cluster.scheduler.schedule
+        pending = 0
+
+        def new_schedule(delay, fn, *args, **kwargs):
+            nonlocal pending
+            pending += 1
+
+            def execute():
+                nonlocal pending
+                try:
+                    fn(*args, **kwargs)
+                finally:
+                    pending -= 1
+
+            scheduler_calls.append((delay, fn, args, kwargs))
+            return original_schedule(0, execute)
+
+        cluster.scheduler.schedule = Mock(side_effect=new_schedule)
+
+        session = cluster.connect()
+        sharding_info = get_sharding_info(session)
+
+        # Scheduled calls are executed in a separate thread, so give them some time to complete
+        while pending > 0:
+            time.sleep(0.01)
+
+        if not sharding_info:
+            # If it is not Scylla, the shard connection backoff scheduler should not be involved
+            for delay, fn, args, kwargs in scheduler_calls:
+                if fn.__self__.__class__ is _ScopeBucket or fn.__self__.__class__ is _NoDelayShardConnectionBackoffScheduler:
+                    self.fail(
+                        "in the non-shard-aware case connections should be created directly, without involving the shard connection backoff scheduler")
+            return
+
+        sleep_time = 0
+        if backoff_policy:
+            schedule = backoff_policy.new_schedule()
+            sleep_time = next(iter(schedule))
+
+        # Make sure that all scheduled calls have a delay according to the policy
+        found_related_calls = 0
+        for delay, fn, args, kwargs in
scheduler_calls: + if fn.__self__.__class__ is _ScopeBucket or fn.__self__.__class__ is _NoDelayShardConnectionBackoffScheduler: + found_related_calls += 1 + self.assertEqual(delay, sleep_time) + self.assertLessEqual(len(session.hosts) * (sharding_info.shards_count - 1), found_related_calls) + + +def get_connections_per_host(session: Session) -> dict[str, int]: + host_connections = {} + for host, pool in session._pools.items(): + host_connections[host.host_id] = len(pool._connections) + return host_connections + + +def get_sharding_info(session: Session) -> Optional[_ShardingInfo]: + for host in session.hosts: + if host.sharding_info: + return host.sharding_info + return None diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index ec7d51bc2d..eb1690fec1 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -148,8 +148,8 @@ def test_event_delay_timing(self, *_): PYTHON-473 """ sched = _Scheduler(None) - sched.schedule(0, lambda: None) - sched.schedule(0, lambda: None) # pre-473: "TypeError: unorderable types: function() < function()"t + sched.schedule(1, lambda: None) + sched.schedule(1, lambda: None) # pre-473: "TypeError: unorderable types: function() < function()"t class SessionTest(unittest.TestCase): diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index b4cb067d2f..6d596cc0c6 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -22,15 +22,23 @@ from threading import Thread, Event, Lock from unittest.mock import Mock, NonCallableMagicMock, MagicMock -from cassandra.cluster import Session, ShardAwareOptions +from cassandra.cluster import Session, ShardAwareOptions, _Scheduler from cassandra.connection import Connection from cassandra.pool import HostConnection, HostConnectionPool from cassandra.pool import Host, NoConnectionsAvailable -from cassandra.policies import HostDistance, SimpleConvictionPolicy +from cassandra.policies import HostDistance, SimpleConvictionPolicy, _NoDelayShardConnectionBackoffScheduler LOGGER = logging.getLogger(__name__) +class FakeScheduler(_Scheduler): + def __init__(self): + super(FakeScheduler, self).__init__(ThreadPoolExecutor()) + + def schedule(self, delay, fn, *args, **kwargs): + super().schedule(0, fn, *args, **kwargs) + + class _PoolTests(unittest.TestCase): __test__ = False PoolImpl = None @@ -41,6 +49,9 @@ def make_session(self): session.cluster.get_core_connections_per_host.return_value = 1 session.cluster.get_max_requests_per_connection.return_value = 1 session.cluster.get_max_connections_per_host.return_value = 1 + session.shard_connection_backoff_scheduler = _NoDelayShardConnectionBackoffScheduler(FakeScheduler()) + session.shard_connection_backoff_scheduler.schedule = Mock(wraps=session.shard_connection_backoff_scheduler.schedule) + session.is_shutdown = False return session def test_borrow_and_return(self): @@ -174,9 +185,9 @@ def test_return_defunct_connection_on_down_host(self): if self.PoolImpl is HostConnection: # on shard aware implementation we use submit function regardless self.assertTrue(host.signal_connection_failure.call_args) - self.assertTrue(session.submit.called) + self.assertTrue(session.shard_connection_backoff_scheduler.schedule.called) else: - self.assertFalse(session.submit.called) + self.assertFalse(session.shard_connection_backoff_scheduler.schedule.called) self.assertTrue(session.cluster.signal_connection_failure.call_args) self.assertTrue(pool.is_shutdown) diff --git 
a/tests/unit/test_policies.py b/tests/unit/test_policies.py index e7757aedfc..f89d1b4348 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -13,6 +13,7 @@ # limitations under the License. import unittest +from functools import partial from itertools import islice, cycle from unittest.mock import Mock, patch, call @@ -26,13 +27,16 @@ from cassandra import ConsistencyLevel from cassandra.cluster import Cluster, ControlConnection from cassandra.metadata import Metadata -from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy, +from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy, + DCAwareRoundRobinPolicy, TokenAwarePolicy, SimpleConvictionPolicy, HostDistance, ExponentialReconnectionPolicy, RetryPolicy, WriteType, DowngradingConsistencyRetryPolicy, ConstantReconnectionPolicy, LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy, - IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy, ExponentialBackoffRetryPolicy) + IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy, + ExponentialBackoffRetryPolicy, _ScopeBucket, _LimitedConcurrencyShardConnectionScheduler, + _NoDelayShardConnectionBackoffScheduler) from cassandra.connection import DefaultEndPoint, UnixSocketEndPoint from cassandra.pool import Host from cassandra.query import Statement @@ -1579,3 +1583,236 @@ def test_create_whitelist(self): # Only the filtered replicas should be allowed self.assertEqual(set(query_plan), {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy), Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy)}) + + +class MockScheduler: + def __init__(self): + self.requests = [] + + def schedule(self, delay, fn, *args, **kwargs): + self.requests.append((delay, fn, args, kwargs)) + + def execute(self): + old_requests = self.requests.copy() + self.requests = [] + for delay, fn, args, kwargs in old_requests: + fn(*args, **kwargs) + + +class NoDelayShardConnectionBackoffSchedulerTests(unittest.TestCase): + def test_schedule_executes_method_immediately(self): + method = Mock() + scheduler = MockScheduler() + policy = _NoDelayShardConnectionBackoffScheduler(scheduler) + + self.assertTrue(policy.schedule('host1', 0, partial(method, 1, 2, key='val'))) + + self.assertEqual(scheduler.requests[0][0], 0) + scheduler.execute() + + method.assert_called_once_with(1, 2, key='val') + + def test_schedule_skips_if_host_shard_already_scheduled(self): + method = Mock() + scheduler = MockScheduler() + policy = _NoDelayShardConnectionBackoffScheduler(scheduler) + + self.assertTrue(policy.schedule('host1', 0, method)) + self.assertFalse(policy.schedule('host1', 0, method)) + + self.assertEqual(len(scheduler.requests), 1) + scheduler.execute() + method.assert_called_once() + + def test_schedule_does_not_skip_if_shard_is_different(self): + method = Mock() + scheduler = MockScheduler() + policy = _NoDelayShardConnectionBackoffScheduler(scheduler) + + self.assertTrue(policy.schedule('host1', 0, method)) + self.assertTrue(policy.schedule('host1', 1, method)) + + self.assertEqual(len(scheduler.requests), 2) + scheduler.execute() + + self.assertEqual(method.call_count, 2) + + def test_already_scheduled_resets_after_execution(self): + method = Mock() + scheduler = MockScheduler() + policy = _NoDelayShardConnectionBackoffScheduler(scheduler) + self.assertTrue(policy.schedule('host1', 0, method)) + + scheduler.execute() + + 
self.assertTrue(policy.schedule('host1', 0, method))
+
+        scheduler.execute()
+
+        self.assertEqual(method.call_count, 2)
+
+    def test_schedule_skips_if_shutdown(self):
+        method = Mock()
+        scheduler = MockScheduler()
+        policy = _NoDelayShardConnectionBackoffScheduler(scheduler)
+        policy.shutdown()
+
+        policy.schedule('host1', 0, method)
+
+        self.assertEqual(len(scheduler.requests), 0)
+
+
+class ScopeBucketTests(unittest.TestCase):
+    def setUp(self):
+        self.reconnection_policy = Mock()
+        self.schedule = [0.1, 0.2, 0.3]
+        self.reconnection_policy.new_schedule.side_effect = lambda: iter(self.schedule)
+
+    def test_add_schedules_initial_task(self):
+        method = Mock()
+        scheduler = MockScheduler()
+        bucket = _ScopeBucket(scheduler, self.reconnection_policy, 1)
+        bucket.schedule_new_connection(partial(method, 1, x=2))
+        self.assertEqual(bucket.currently_pending, 1)
+        self.assertEqual(len(scheduler.requests), 1)
+        self.assertEqual(scheduler.requests[0][0], 0.1)
+        scheduler.execute()
+        method.assert_called_once()
+
+    def test_max_concurrent_1(self):
+        self._test_multiple_adds_only_schedule_once(1)
+
+    def test_max_concurrent_2(self):
+        self._test_multiple_adds_only_schedule_once(2)
+
+    def _test_multiple_adds_only_schedule_once(self, max_concurrent):
+        scheduler = MockScheduler()
+        bucket = _ScopeBucket(scheduler, self.reconnection_policy, max_concurrent)
+        method1 = Mock()
+        method2 = Mock()
+
+        bucket.schedule_new_connection(partial(method1, "a"))
+        bucket.schedule_new_connection(partial(method2, "b"))
+        # Only `max_concurrent` runs should be scheduled
+
+        self.assertEqual(len(scheduler.requests), max_concurrent)
+
+        schedule = iter(self.schedule)
+        delay = next(schedule)
+        self.assertEqual(scheduler.requests[0][0], delay)
+        if max_concurrent == 2:
+            self.assertEqual(scheduler.requests[1][0], delay)
+
+        # Both methods are enqueued
+        self.assertEqual(len(bucket.items), 2)
+        self.assertEqual(bucket.currently_pending, max_concurrent)
+        scheduler.execute()
+        if max_concurrent == 2:
+            self.assertEqual(len(bucket.items), 0)
+            self.assertEqual(bucket.currently_pending, 2)
+            scheduler.execute()
+            self.assertEqual(bucket.currently_pending, 0)
+            return
+
+        self.assertEqual(len(bucket.items), 1)
+        self.assertEqual(bucket.currently_pending, 1)
+        delay = next(schedule)
+        self.assertEqual(scheduler.requests[0][0], delay)
+
+        scheduler.execute()
+        self.assertEqual(len(bucket.items), 0)
+        self.assertEqual(bucket.currently_pending, 1)
+        delay = next(schedule)
+        self.assertEqual(scheduler.requests[0][0], delay)
+
+        scheduler.execute()
+        self.assertEqual(bucket.currently_pending, 0)
+
+    def test_does_not_schedule_if_shutdown(self):
+        scheduler = MockScheduler()
+        bucket = _ScopeBucket(scheduler, self.reconnection_policy, 2)
+        bucket.shutdown()
+        method = Mock()
+
+        bucket.schedule_new_connection(partial(method))
+        self.assertEqual(len(scheduler.requests), 0)
+
+    def test_get_delay_resets_schedule_on_stopiteration(self):
+        scheduler = MockScheduler()
+        bucket = _ScopeBucket(scheduler, self.reconnection_policy, 1)
+
+        method = Mock()
+
+        for delay in self.schedule:
+            bucket.schedule_new_connection(partial(method))
+            self.assertEqual(scheduler.requests[0][0], delay)
+            scheduler.execute()
+
+        # _ScopeBucket has to reset its schedule because it is exhausted
+        for delay in self.schedule:
+            bucket.schedule_new_connection(partial(method))
+            self.assertEqual(scheduler.requests[0][0], delay)
+            scheduler.execute()
+
+
+class LimitedConcurrencyShardConnectionSchedulerTests(unittest.TestCase):
+    def setUp(self):
+        self.mock_scheduler =
Mock()
+
+        self.reconnection_policy = Mock()
+        self.reconnection_policy.new_schedule.return_value = cycle([0])
+
+        self.method = Mock()
+        self.host_id = 'host123'
+        self.shard_id = 0
+
+    def test_schedules_once_per_key(self):
+        scheduler = _LimitedConcurrencyShardConnectionScheduler(
+            self.mock_scheduler, self.reconnection_policy, 1
+        )
+
+        scheduled = scheduler.schedule(self.host_id, self.shard_id, self.method)
+        self.assertTrue(scheduled)
+        # Try to schedule again for the same key: should be rejected
+        scheduled2 = scheduler.schedule(self.host_id, self.shard_id, self.method)
+        self.assertFalse(scheduled2)
+
+        # A _ScopeBucket should have been created for the host's scope
+        self.assertEqual(len(scheduler.per_host_scope), 1)
+        scope = next(iter(scheduler.per_host_scope.values()))
+        self.assertIsNotNone(scope)
+        self.assertEqual(len(scope.items), 1)
+
+    def test_schedule_separate_keys(self):
+        scheduler = _LimitedConcurrencyShardConnectionScheduler(
+            self.mock_scheduler, self.reconnection_policy, 1
+        )
+
+        scheduled1 = scheduler.schedule('host1', 1, self.method)
+        scheduled2 = scheduler.schedule('host1', 2, self.method)
+        scheduled3 = scheduler.schedule('host2', 1, self.method)
+
+        self.assertTrue(scheduled1)
+        self.assertTrue(scheduled2)
+        self.assertTrue(scheduled3)
+
+        # Should create scopes for both hosts
+        self.assertIn('host1', scheduler.per_host_scope)
+        self.assertIn('host2', scheduler.per_host_scope)
+        self.assertEqual(len(scheduler.per_host_scope['host1'].items), 2)
+        self.assertEqual(len(scheduler.per_host_scope['host2'].items), 1)
+
+    def test_execute_resets_already_scheduled_flag(self):
+        scheduler = _LimitedConcurrencyShardConnectionScheduler(
+            self.mock_scheduler, self.reconnection_policy, 1
+        )
+
+        self.assertTrue(scheduler.schedule(self.host_id, self.shard_id, self.method))
+
+        # Simulate running the scheduled task manually
+        self.assertTrue((self.host_id, self.shard_id) in scheduler.already_scheduled)
+        scheduler._execute(self.host_id, self.shard_id, self.method)
+
+        # Should now be marked not scheduled
+        self.assertFalse((self.host_id, self.shard_id) in scheduler.already_scheduled)
+        self.method.assert_called_once()
diff --git a/tests/unit/test_shard_aware.py b/tests/unit/test_shard_aware.py
index fe7b95edba..fd1d42aab9 100644
--- a/tests/unit/test_shard_aware.py
+++ b/tests/unit/test_shard_aware.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import uuid
+from unittest.mock import Mock
 
 try:
     import unittest2 as unittest
@@ -21,7 +23,10 @@
 from mock import MagicMock
 from concurrent.futures import ThreadPoolExecutor
 
-from cassandra.cluster import ShardAwareOptions
+from cassandra.cluster import ShardAwareOptions, _Scheduler
+from cassandra.policies import ConstantReconnectionPolicy, \
+    NoDelayShardConnectionBackoffPolicy, LimitedConcurrencyShardConnectionBackoffPolicy, _ScopeBucket, \
+    _NoDelayShardConnectionBackoffScheduler
 from cassandra.pool import HostConnection, HostDistance
 from cassandra.connection import ShardingInfo, DefaultEndPoint
 from cassandra.metadata import Murmur3Token
@@ -53,11 +58,35 @@ class OptionsHolder(object):
         self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"e").value), 4)
         self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"100000").value), 2)
 
-    def test_advanced_shard_aware_port(self):
+    def test_shard_aware_reconnection_policy_no_delay(self):
+        # with NoDelayShardConnectionBackoffPolicy all the connections should be created right away
+        self._test_shard_aware_reconnection_policy(4, NoDelayShardConnectionBackoffPolicy(), 4)
+
+    def test_shard_aware_reconnection_policy_delay(self):
+        # with ConstantReconnectionPolicy the first connection is created right away, the others are delayed
+        self._test_shard_aware_reconnection_policy(
+            4,
+            LimitedConcurrencyShardConnectionBackoffPolicy(
+                ConstantReconnectionPolicy(0.1),
+                1
+            ), 4)
+
+    def test_shard_aware_reconnection_policy_delay_non_scylla(self):
+        self._test_shard_aware_reconnection_policy(
+            0,
+            LimitedConcurrencyShardConnectionBackoffPolicy(
+                ConstantReconnectionPolicy(0.1),
+                1
+            ), 1)
+
+    def _test_shard_aware_reconnection_policy(self, shard_count, shard_connection_backoff_policy, expected_connections):
         """
         Test that on given a `shard_aware_port` on the OPTIONS message (ShardInfo class)
-        the next connections would be open using this port
+        the next connections are opened using this port. It checks that:
+        1. Next connections are opened using this port
+        2. Connection creation pace matches `shard_connection_backoff_policy`
         """
+
         class MockSession(MagicMock):
             is_shutdown = False
             keyspace = "ks1"
@@ -71,17 +100,36 @@ def __init__(self, is_ssl=False, *args, **kwargs):
             self.cluster.ssl_options = None
             self.cluster.shard_aware_options = ShardAwareOptions()
             self.cluster.executor = ThreadPoolExecutor(max_workers=2)
+            self._executor_submit_original = self.cluster.executor.submit
+            self.cluster.executor.submit = self._executor_submit
+            self.cluster.scheduler = _Scheduler(self.cluster.executor)
+
+            # Collect scheduled calls and execute them right away
+            self.scheduler_calls = []
+            original_schedule = self.cluster.scheduler.schedule
+
+            def new_schedule(delay, fn, *args, **kwargs):
+                self.scheduler_calls.append((delay, fn, args, kwargs))
+                return original_schedule(0, fn, *args, **kwargs)
+
+            self.cluster.scheduler.schedule = Mock(side_effect=new_schedule)
             self.cluster.signal_connection_failure = lambda *args, **kwargs: False
             self.cluster.connection_factory = self.mock_connection_factory
             self.connection_counter = 0
+            self.shard_connection_backoff_scheduler = shard_connection_backoff_policy.new_connection_scheduler(
+                self.cluster.scheduler)
             self.futures = []
 
         def submit(self, fn, *args, **kwargs):
+            if self.is_shutdown:
+                return None
+            return self.cluster.executor.submit(fn, *args, **kwargs)
+
+        def _executor_submit(self, fn, *args, **kwargs):
             logging.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs)
-            if not self.is_shutdown:
-                f = self.cluster.executor.submit(fn, *args, **kwargs)
-                self.futures += [f]
-                return f
+            f = self._executor_submit_original(fn, *args, **kwargs)
+            self.futures += [f]
+            return f
 
         def mock_connection_factory(self, *args, **kwargs):
             connection = MagicMock()
@@ -89,27 +137,52 @@ def mock_connection_factory(self, *args, **kwargs):
             connection.is_defunct = False
             connection.is_closed = False
             connection.orphaned_threshold_reached = False
-            connection.endpoint = args[0]
-            sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045)
-            connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info)
+            connection.endpoint = args[0]
+            sharding_info = None
+            if shard_count:
+                sharding_info = ShardingInfo(shard_id=1, shards_count=shard_count, partitioner="",
+                                             sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042,
+                                             shard_aware_port_ssl=19045)
+            connection.features = ProtocolFeatures(
+                shard_id=kwargs.get('shard_id', self.connection_counter),
+                sharding_info=sharding_info)
             self.connection_counter += 1
 
             return connection
 
         host = MagicMock()
+        host.host_id = uuid.uuid4()
         host.endpoint = DefaultEndPoint("1.2.3.4")
+        session = None
+        backoff_policy = None
+        if isinstance(shard_connection_backoff_policy, LimitedConcurrencyShardConnectionBackoffPolicy):
+            backoff_policy = shard_connection_backoff_policy.backoff_policy
 
-        for port, is_ssl in [(19042, False), (19045, True)]:
-            session = MockSession(is_ssl=is_ssl)
-            pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
-            for f in session.futures:
-                f.result()
-            assert len(pool._connections) == 4
-            for shard_id, connection in pool._connections.items():
-                assert connection.features.shard_id == shard_id
-                if shard_id == 0:
-                    assert connection.endpoint == DefaultEndPoint("1.2.3.4")
-                else:
-                    assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port)
+        try:
+            for port, is_ssl
in [(19042, False), (19045, True)]: + session = MockSession(is_ssl=is_ssl) + pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session) + for f in session.futures: + f.result() + assert len(pool._connections) == expected_connections + for shard_id, connection in pool._connections.items(): + assert connection.features.shard_id == shard_id + if shard_id == 0: + assert connection.endpoint == DefaultEndPoint("1.2.3.4") + else: + assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port) + + sleep_time = 0 + if backoff_policy: + sleep_time = next(iter(backoff_policy.new_schedule())) - session.cluster.executor.shutdown(wait=True) + found_related_calls = 0 + for delay, fn, args, kwargs in session.scheduler_calls: + if fn.__self__.__class__ in (_ScopeBucket, _NoDelayShardConnectionBackoffScheduler): + found_related_calls += 1 + self.assertEqual(delay, sleep_time) + self.assertLessEqual(shard_count - 1, found_related_calls) + finally: + if session: + session.cluster.scheduler.shutdown() + session.cluster.executor.shutdown(wait=True)
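For completeness, a hedged sketch of the other backoff source this diff accepts: `ConstantShardConnectionBackoffSchedule`, the `ShardConnectionBackoffSchedule` subclass added in cassandra/policies.py, which supports optional jitter. Delay and jitter values here are illustrative, not recommendations:

```python
from cassandra.cluster import Cluster
from cassandra.policies import (ConstantShardConnectionBackoffSchedule,
                                LimitedConcurrencyShardConnectionBackoffPolicy)

# Allow up to two concurrent shard connection attempts per host, each attempt
# delayed by 50 ms plus up to 20 ms of random jitter.
policy = LimitedConcurrencyShardConnectionBackoffPolicy(
    backoff_policy=ConstantShardConnectionBackoffSchedule(delay=0.05, jitter=0.02),
    max_concurrent=2,
)
cluster = Cluster(shard_connection_backoff_policy=policy)
session = cluster.connect()
```

Since the policy's `isinstance` check accepts either a `ShardConnectionBackoffSchedule` or a `ReconnectionPolicy`, an existing `ExponentialReconnectionPolicy` can be dropped in the same way to get exponential backoff between shard connection attempts.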