Skip to content

Commit

Permalink
fix: support both types of cn endpoint patterns (#701)
Browse files Browse the repository at this point in the history
  • Loading branch information
karenc-bq authored Oct 18, 2024
1 parent d0bce5f commit eeebefa
Show file tree
Hide file tree
Showing 9 changed files with 267 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,12 @@ def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection):
"""

aliases: FrozenSet[str] = host_info.as_aliases()
host: str = host_info.as_alias()

if self._rds_utils.is_rds_instance(host):
self._track_connection(host, conn)
if self._rds_utils.is_rds_instance(host_info.host):
self._track_connection(host_info.as_alias(), conn)
return

instance_endpoint: Optional[str] = next((alias for alias in aliases if self._rds_utils.is_rds_instance(alias)),
instance_endpoint: Optional[str] = next((alias for alias in aliases if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))),
None)
if not instance_endpoint:
logger.debug("OpenedConnectionTracker.UnableToPopulateOpenedConnectionSet")
Expand All @@ -82,7 +81,7 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host:
return

for instance in host:
if instance is not None and self._rds_utils.is_rds_instance(instance):
if instance is not None and self._rds_utils.is_rds_instance(self._rds_utils.remove_port(instance)):
instance_endpoint = instance
break

Expand Down
9 changes: 6 additions & 3 deletions aws_advanced_python_wrapper/host_list_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ def _initialize(self):
else:
self._cluster_instance_template = HostInfo(
host=self._rds_utils.get_rds_instance_host_pattern(self._initial_host_info.host),
host_id=self._initial_host_info.host_id,
port=self._initial_host_info.port,
host_availability_strategy=host_availability_strategy)
self._validate_host_pattern(self._cluster_instance_template.host)

Expand All @@ -216,14 +218,15 @@ def _initialize(self):
self._cluster_id = cluster_id_suggestion.cluster_id
self._is_primary_cluster_id = cluster_id_suggestion.is_primary_cluster_id
else:
cluster_url = self._rds_utils.get_rds_cluster_host_url(self._initial_host_info.url)
cluster_url = self._rds_utils.get_rds_cluster_host_url(self._initial_host_info.host)
if cluster_url is not None:
self._cluster_id = cluster_url
self._cluster_id = f"{cluster_url}:{self._cluster_instance_template.port}" \
if self._cluster_instance_template.is_port_specified() else cluster_url
self._is_primary_cluster_id = True
self._is_primary_cluster_id_cache.put(self._cluster_id, True,
self._suggested_cluster_id_refresh_ns)

self._is_initialized = True
self._is_initialized = True

def _validate_host_pattern(self, host: str):
if not self._rds_utils.is_dns_pattern_valid(host):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ def __init__(
self,
pool_configurator: Optional[Callable] = None,
pool_mapping: Optional[Callable] = None,
accept_url_func: Optional[Callable] = None,
pool_expiration_check_ns: int = -1,
pool_cleanup_interval_ns: int = -1):
self._pool_configurator = pool_configurator
self._pool_mapping = pool_mapping
self._accept_url_func = accept_url_func

if pool_expiration_check_ns > -1:
SqlAlchemyPooledConnectionProvider._POOL_EXPIRATION_CHECK_NS = pool_expiration_check_ns
Expand All @@ -80,6 +82,8 @@ def keys(self):
return self._database_pools.keys()

def accepts_host_info(self, host_info: HostInfo, props: Properties) -> bool:
if self._accept_url_func:
return self._accept_url_func(host_info, props)
url_type = SqlAlchemyPooledConnectionProvider._rds_utils.identify_rds_type(host_info.host)
return RdsUrlType.RDS_INSTANCE == url_type

Expand Down
204 changes: 129 additions & 75 deletions aws_advanced_python_wrapper/utils/rdsutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from re import search, sub
from typing import Optional
from __future__ import annotations

from re import Match, search, sub
from typing import Dict, Optional

from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType

Expand Down Expand Up @@ -58,135 +60,156 @@ class RdsUtils:
Example: test-postgres-instance-1.123456789012.rds.cn-northwest-1.amazonaws.com.cn
"""

AURORA_DNS_PATTERN = r"(?P<instance>.+)\." \
r"(?P<dns>proxy-|cluster-|cluster-ro-|cluster-custom-)?" \
AURORA_DNS_PATTERN = r"^(?P<instance>.+)\." \
r"(?P<dns>proxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \
r"(?P<domain>[a-zA-Z0-9]+\." \
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn$)"
AURORA_INSTANCE_PATTERN = r"(?P<instance>.+)\." \
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn)$"
AURORA_INSTANCE_PATTERN = r"^(?P<instance>.+)\." \
r"(?P<domain>[a-zA-Z0-9]+\." \
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn$)"
AURORA_CLUSTER_PATTERN = r"(?P<instance>.+)\." \
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn)$"
AURORA_CLUSTER_PATTERN = r"^(?P<instance>.+)\." \
r"(?P<dns>cluster-|cluster-ro-)+" \
r"(?P<domain>[a-zA-Z0-9]+\." \
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn$)"
AURORA_CUSTOM_CLUSTER_PATTERN = r"(?P<instance>.+)\." \
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn)$"
AURORA_CUSTOM_CLUSTER_PATTERN = r"^(?P<instance>.+)\." \
r"(?P<dns>cluster-custom-)+" \
r"(?P<domain>[a-zA-Z0-9]+\." \
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn$)"
AURORA_PROXY_DNS_PATTERN = r"(?P<instance>.+)\." \
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn)$"
AURORA_PROXY_DNS_PATTERN = r"^(?P<instance>.+)\." \
r"(?P<dns>proxy-)+" \
r"(?P<domain>[a-zA-Z0-9]+\." \
r"(?P<region>[a-zA-Z0-9\\-]+)\.rds\.amazonaws\.com)(?!\.cn$)"
AURORA_CHINA_DNS_PATTERN = r"(?P<instance>.+)\." \
r"(?P<dns>proxy-|cluster-|cluster-ro-|cluster-custom-)?" \
r"(?P<region>[a-zA-Z0-9\\-]+)\.rds\.amazonaws\.com)(?!\.cn)$"
AURORA_OLD_CHINA_DNS_PATTERN = r"^(?P<instance>.+)\." \
r"(?P<dns>proxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \
r"(?P<domain>[a-zA-Z0-9]+\." \
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)$"
AURORA_CHINA_DNS_PATTERN = r"^(?P<instance>.+)\." \
r"(?P<dns>proxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \
r"(?P<domain>[a-zA-Z0-9]+\." \
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)"
AURORA_CHINA_INSTANCE_PATTERN = r"(?P<instance>.+)\." \
r"(?P<domain>[a-zA-Z0-9]+\." \
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)"
AURORA_CHINA_CLUSTER_PATTERN = r"(?P<instance>.+)\." \
r"rds\.(?P<region>[a-zA-Z0-9\-]+)\.amazonaws\.com\.cn)$"
AURORA_OLD_CHINA_CLUSTER_PATTERN = r"^(?P<instance>.+)\." \
r"(?P<dns>cluster-|cluster-ro-)+" \
r"(?P<domain>[a-zA-Z0-9]+\." \
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)$"
AURORA_CHINA_CLUSTER_PATTERN = r"^(?P<instance>.+)\." \
r"(?P<dns>cluster-|cluster-ro-)+" \
r"(?P<domain>[a-zA-Z0-9]+\." \
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)"
AURORA_CHINA_CUSTOM_CLUSTER_PATTERN = r"(?P<instance>.+)\." \
r"(?P<dns>cluster-custom-)+" \
r"(?P<domain>[a-zA-Z0-9]+\." \
r"(?P<region>[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)"
AURORA_CHINA_PROXY_DNS_PATTERN = r"(?P<instance>.+)\." \
r"(?P<dns>proxy-)+" \
r"(?P<domain>[a-zA-Z0-9]+\." \
r"(?P<region>[a-zA-Z0-9\-])+\.rds\.amazonaws\.com\.cn)"
r"rds\.(?P<region>[a-zA-Z0-9\-]+)\.amazonaws\.com\.cn)$"
AURORA_GOV_DNS_PATTERN = r"^(?P<instance>.+)\." \
r"(?P<dns>proxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \
r"(?P<domain>[a-zA-Z0-9]+\.rds\.(?P<region>[a-zA-Z0-9\-]+)" \
r"\.(amazonaws\.com|c2s\.ic\.gov|sc2s\.sgov\.gov))$"
AURORA_GOV_CLUSTER_PATTERN = r"^(?P<instance>.+)\." \
r"(?P<dns>cluster-|cluster-ro-)+" \
r"(?P<domain>[a-zA-Z0-9]+\.rds\.(?P<region>[a-zA-Z0-9\-]+)" \
r"\.(amazonaws\.com|c2s\.ic\.gov|sc2s\.sgov\.gov))$"
# Hostname pattern with an ".elb." label followed by "<region>.amazonaws.com".
# Named groups must use Python's (?P<name>...) syntax; the (?<name>...) form
# makes the re module raise re.error as soon as the pattern is compiled.
ELB_PATTERN = r"^(?P<instance>.+)\.elb\.((?P<region>[a-zA-Z0-9\-]+)\.amazonaws\.com)$"

IP_V4 = r"^(([1-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){1}" \
r"(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){2}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$"
IP_V6 = r"^[0-9a-fA-F]{1,4}(:[0-9a-fA-F]{1,4}){7}$"
IP_V6_COMPRESSED = r"^(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)::(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)$"
r"(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){2}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])"
IP_V6 = r"^[0-9a-fA-F]{1,4}(:[0-9a-fA-F]{1,4}){7}"
IP_V6_COMPRESSED = r"^(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)::(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)"

DNS_GROUP = "dns"
DOMAIN_GROUP = "domain"
INSTANCE_GROUP = "instance"
REGION_GROUP = "region"

CACHE_DNS_PATTERNS: Dict[str, Match[str]] = {}
CACHE_PATTERNS: Dict[str, str] = {}

def is_rds_cluster_dns(self, host: str) -> bool:
return self._contains(host, [self.AURORA_CLUSTER_PATTERN, self.AURORA_CHINA_CLUSTER_PATTERN])
dns_group = self._get_dns_group(host)
return dns_group is not None and dns_group.casefold() in ["cluster-", "cluster-ro-"]

def is_rds_custom_cluster_dns(self, host: str) -> bool:
return self._contains(host, [self.AURORA_CUSTOM_CLUSTER_PATTERN, self.AURORA_CHINA_CUSTOM_CLUSTER_PATTERN])
dns_group = self._get_dns_group(host)
return dns_group is not None and dns_group.casefold() == "cluster-custom-"

def is_rds_dns(self, host: str) -> bool:
return self._contains(host, [self.AURORA_DNS_PATTERN, self.AURORA_CHINA_DNS_PATTERN])
if not host or not host.strip():
return False

pattern = self._find(host, [RdsUtils.AURORA_DNS_PATTERN,
RdsUtils.AURORA_CHINA_DNS_PATTERN,
RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN,
RdsUtils.AURORA_GOV_DNS_PATTERN])
group = self._get_regex_group(pattern, RdsUtils.DNS_GROUP)

if group:
RdsUtils.CACHE_PATTERNS[host] = group

return pattern is not None

def is_rds_instance(self, host: str) -> bool:
return (self._contains(host, [self.AURORA_INSTANCE_PATTERN, self.AURORA_CHINA_INSTANCE_PATTERN])
and self.is_rds_dns(host))
return self._get_dns_group(host) is None and self.is_rds_dns(host)

def is_rds_proxy_dns(self, host: str) -> bool:
return self._contains(host, [self.AURORA_PROXY_DNS_PATTERN, self.AURORA_CHINA_PROXY_DNS_PATTERN])
dns_group = self._get_dns_group(host)
return dns_group is not None and dns_group.casefold() == "proxy-"

def get_rds_instance_host_pattern(self, host: str) -> str:
if not host or not host.strip():
return "?"

match = self._find(host, [self.AURORA_DNS_PATTERN, self.AURORA_CHINA_DNS_PATTERN])
match = self._get_group(host, RdsUtils.DOMAIN_GROUP)
if match:
return f"?.{match.group(self.DOMAIN_GROUP)}"
return f"?.{match}"

return "?"

def get_rds_region(self, host: Optional[str]):
if not host or not host.strip():
return None

match = self._find(host, [self.AURORA_DNS_PATTERN, self.AURORA_CHINA_DNS_PATTERN])
if match:
return match.group(self.REGION_GROUP)
group = self._get_group(host, RdsUtils.REGION_GROUP)
if group:
return group

elb_matcher = search(RdsUtils.ELB_PATTERN, host)
if elb_matcher:
return elb_matcher.group(RdsUtils.REGION_GROUP)
return None

def is_writer_cluster_dns(self, host: str) -> bool:
if not host or not host.strip():
return False

match = self._find(host, [self.AURORA_CLUSTER_PATTERN, self.AURORA_CHINA_CLUSTER_PATTERN])
if match:
return "cluster-".casefold() == match.group(self.DNS_GROUP).casefold()

return False
dns_group = self._get_dns_group(host)
return dns_group is not None and dns_group.casefold() == "cluster-"

def is_reader_cluster_dns(self, host: str) -> bool:
match = self._find(host, [self.AURORA_CLUSTER_PATTERN, self.AURORA_CHINA_CLUSTER_PATTERN])
if match:
return "cluster-ro-".casefold() == match.group(self.DNS_GROUP).casefold()

return False
dns_group = self._get_dns_group(host)
return dns_group is not None and dns_group.casefold() == "cluster-ro-"

def get_rds_cluster_host_url(self, host: str):
if not host or not host.strip():
return None

if search(self.AURORA_CLUSTER_PATTERN, host):
return sub(self.AURORA_CLUSTER_PATTERN, r"\g<instance>.cluster-\g<domain>", host)

if search(self.AURORA_CHINA_CLUSTER_PATTERN, host):
return sub(self.AURORA_CHINA_CLUSTER_PATTERN, r"\g<instance>.cluster-\g<domain>", host)
for pattern in [RdsUtils.AURORA_DNS_PATTERN,
RdsUtils.AURORA_CHINA_DNS_PATTERN,
RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN,
RdsUtils.AURORA_GOV_DNS_PATTERN]:
if m := search(pattern, host):
group = self._get_regex_group(m, RdsUtils.DNS_GROUP)
if group is not None:
return sub(pattern, r"\g<instance>.cluster-\g<domain>", host)
return None

return None

def get_instance_id(self, host: str) -> Optional[str]:
if not host or not host.strip():
return None

match = self._find(host, [self.AURORA_INSTANCE_PATTERN, self.AURORA_CHINA_INSTANCE_PATTERN])
if match:
return match.group(self.INSTANCE_GROUP)
if self._get_dns_group(host) is None:
return self._get_group(host, self.INSTANCE_GROUP)

return None

def is_ipv4(self, host: str) -> bool:
return self._contains(host, [self.IP_V4])
if host is None or not host.strip():
return False
return search(RdsUtils.IP_V4, host) is not None

def is_ipv6(self, host: str) -> bool:
return self._contains(host, [self.IP_V6, self.IP_V6_COMPRESSED])
if host is None or not host.strip():
return False
return search(RdsUtils.IP_V6_COMPRESSED, host) is not None or search(RdsUtils.IP_V6, host) is not None

# Report whether ``host`` contains the "?" character.
# Returns True when "?" occurs anywhere in ``host`` and False otherwise.
def is_dns_pattern_valid(self, host: str) -> bool:
return "?" in host
Expand All @@ -210,17 +233,48 @@ def identify_rds_type(self, host: Optional[str]) -> RdsUrlType:

return RdsUrlType.OTHER

def _contains(self, host: str, patterns: list) -> bool:
if not host or not host.strip():
return False

return len([pattern for pattern in patterns if search(pattern, host)]) > 0

def _find(self, host: str, patterns: list):
    """Return the first regex match of ``host`` against ``patterns``.

    Results are memoised in the class-level ``RdsUtils.CACHE_DNS_PATTERNS``
    dict, keyed by ``host`` only. A cached match is therefore returned
    without consulting ``patterns`` at all.

    :param host: the host name to examine.
    :param patterns: regex pattern strings, tried in order.
    :return: the cached or first successful match object, or ``None`` when
        ``host`` is empty/blank or no pattern matches.
    """
    if not host or not host.strip():
        return None

    # The cache key is the host alone, so the lookup is independent of the
    # pattern being tried: do it once up front rather than on every iteration.
    cached = RdsUtils.CACHE_DNS_PATTERNS.get(host)
    if cached is not None:
        return cached

    for pattern in patterns:
        match = search(pattern, host)
        if match:
            RdsUtils.CACHE_DNS_PATTERNS[host] = match
            return match

    return None

def _get_regex_group(self, pattern: Match[str], group_name: str):
if pattern is None:
return None
return pattern.group(group_name)

def _get_group(self, host: str, group: str):
    """Look up the named regex group ``group`` for ``host``.

    ``host`` is matched (via ``_find``) against the standard, China,
    old-China and gov DNS patterns, in that order.

    :param host: the host name to examine.
    :param group: name of the regex group to extract.
    :return: the group's value, or ``None`` when ``host`` is empty/blank or
        ``_find`` returns no match.
    """
    if not host or not host.strip():
        return None

    dns_patterns = [
        RdsUtils.AURORA_DNS_PATTERN,
        RdsUtils.AURORA_CHINA_DNS_PATTERN,
        RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN,
        RdsUtils.AURORA_GOV_DNS_PATTERN,
    ]
    found = self._find(host, dns_patterns)
    return self._get_regex_group(found, group)

def _get_dns_group(self, host: str):
    """Return the value of the ``RdsUtils.DNS_GROUP`` regex group for ``host``.

    Delegates to ``_get_group``, so the result is ``None`` whenever that
    lookup yields nothing.
    """
    dns_group_name = RdsUtils.DNS_GROUP
    return self._get_group(host, dns_group_name)

def remove_port(self, url: str) -> Optional[str]:
    """Strip everything from the first ``":"`` onwards from ``url``.

    :param url: a ``host`` or ``host:port`` string.
    :return: the text before the first ``":"`` (the whole string when there
        is no ``":"``), or ``None`` when ``url`` is empty or blank.
    """
    if not url or not url.strip():
        return None
    # partition() yields the whole string as its head when ":" is absent,
    # so no separate membership test is needed.
    return url.partition(":")[0]

@staticmethod
def clear_cache():
    """Empty both class-level caches, ``CACHE_PATTERNS`` and ``CACHE_DNS_PATTERNS``."""
    for cache in (RdsUtils.CACHE_PATTERNS, RdsUtils.CACHE_DNS_PATTERNS):
        cache.clear()
2 changes: 2 additions & 0 deletions tests/integration/container/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from aws_advanced_python_wrapper.host_list_provider import RdsHostListProvider
from aws_advanced_python_wrapper.plugin_service import PluginServiceImpl
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils

if TYPE_CHECKING:
from .utils.test_driver import TestDriver
Expand Down Expand Up @@ -124,6 +125,7 @@ def pytest_runtest_setup(item):

assert cluster_ip == writer_ip

RdsUtils.clear_cache()
RdsHostListProvider._topology_cache.clear()
RdsHostListProvider._is_primary_cluster_id_cache.clear()
RdsHostListProvider._cluster_ids_to_update.clear()
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/container/test_autoscaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_pooled_connection_auto_scaling__set_read_only_on_old_connection(
provider = SqlAlchemyPooledConnectionProvider(
lambda _, __: {"pool_size": original_cluster_size},
None,
None,
120000000000, # 2 minutes
180000000000) # 3 minutes
ConnectionProviderManager.set_connection_provider(provider)
Expand Down Expand Up @@ -167,6 +168,7 @@ def test_pooled_connection_auto_scaling__failover_from_deleted_reader(
provider = SqlAlchemyPooledConnectionProvider(
lambda _, __: {"pool_size": len(instances) * 5},
None,
None,
120000000000, # 2 minutes
180000000000) # 3 minutes
ConnectionProviderManager.set_connection_provider(provider)
Expand Down
Loading

0 comments on commit eeebefa

Please sign in to comment.