Skip to content

Commit

Permalink
add cached file to zeep transport
Browse files Browse the repository at this point in the history
  • Loading branch information
austinmroczek committed May 4, 2024
1 parent c18db86 commit 20373cb
Showing 1 changed file with 62 additions and 29 deletions.
91 changes: 62 additions & 29 deletions total_connect_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@
import logging
import ssl
import time
from importlib import resources as impresources

import requests
import urllib3.poolmanager
import zeep
import zeep.cache
from zeep.exceptions import Fault as ZeepFault
import zeep.transports
import requests
import urllib3.poolmanager
from zeep.exceptions import Fault as ZeepFault

from . import cache as cache_folder
from .const import ArmType, _ResultCode
from .exceptions import (
AuthenticationError,
Expand All @@ -37,27 +39,35 @@

DEFAULT_USERCODE = "-1"

SCHEMAS_TO_CACHE = {
"https://schemas.xmlsoap.org/soap/encoding/": "soap-encodings-schemas.xmlsoap.org.txt",
}

LOGGER = logging.getLogger(__name__)


class _SslContextAdapter(requests.adapters.HTTPAdapter):
"""Makes Zeep use our ssl_context."""

def __init__(self, ssl_context, **kwargs):
self.ssl_context = ssl_context
super().__init__(**kwargs)

def init_poolmanager(self, num_pools, maxsize, block=False):
self.poolmanager = urllib3.poolmanager.PoolManager(
num_pools=num_pools, maxsize=maxsize,
block=block, ssl_context=self.ssl_context)
num_pools=num_pools,
maxsize=maxsize,
block=block,
ssl_context=self.ssl_context,
)


class TotalConnectClient:
"""Client for Total Connect."""

TIMEOUT = 60 # seconds until SOAP I/O will fail

def __init__( # pylint: disable=too-many-arguments
def __init__( # pylint: disable=too-many-arguments
self,
username,
password,
Expand Down Expand Up @@ -95,7 +105,7 @@ def locations(self):
"""
# to_fetch is needed because items() is invalidated by del
to_fetch = list(self._locations_unfetched.items())
for (locationid, location) in to_fetch:
for locationid, location in to_fetch:
try:
location.get_partition_details()
location.get_zone_details()
Expand Down Expand Up @@ -154,21 +164,22 @@ def _raise_for_retry(self, response):
if rc == _ResultCode.FAILED_TO_CONNECT:
raise RetryableTotalConnectError("failed to connect with panel", response)
if rc == _ResultCode.AUTHENTICATION_FAILED:
raise RetryableTotalConnectError("temporary authentication failure", response)
raise RetryableTotalConnectError(
"temporary authentication failure", response
)
if rc == _ResultCode.BAD_OBJECT_REFERENCE:
raise RetryableTotalConnectError("bad object reference", response)


def raise_for_resultcode(self, response):
"""If response.ResultCode indicates success, return and do nothing.
If it indicates an authentication error, raise AuthenticationError.
"""
rc = _ResultCode.from_response(response)
if rc in (
_ResultCode.SUCCESS,
_ResultCode.ARM_SUCCESS,
_ResultCode.DISARM_SUCCESS,
_ResultCode.SESSION_INITIATED,
_ResultCode.SUCCESS,
_ResultCode.ARM_SUCCESS,
_ResultCode.DISARM_SUCCESS,
_ResultCode.SESSION_INITIATED,
):
return
self._raise_for_retry(response)
Expand All @@ -181,7 +192,7 @@ def raise_for_resultcode(self, response):
if rc == _ResultCode.FEATURE_NOT_SUPPORTED:
raise FeatureNotSupportedError(rc.name, response)
if rc == _ResultCode.FAILED_TO_BYPASS_ZONE:
raise FailedToBypassZone(rc.name, response)
raise FailedToBypassZone(rc.name, response)
raise BadResultCodeError(rc.name, response)

def _send_one_request(self, operation_name, args):
Expand All @@ -206,10 +217,15 @@ def request(self, operation_name, args, attempts_remaining=5):
session = requests.Session()
ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
ctx.options |= 0x04 # ssl.OP_LEGACY_SERVER_CONNECT once that exists
session.mount('https://', _SslContextAdapter(ctx))
session.mount("https://", _SslContextAdapter(ctx))
cache = zeep.cache.InMemoryCache()
for url, filename in SCHEMAS_TO_CACHE.items():
cache_file = impresources.files(cache_folder) / filename
with cache_file.open() as file:
cache.add(url, file.read())
transport = zeep.transports.Transport(
session=session,
cache=zeep.cache.InMemoryCache(timeout=3600),
cache=cache,
timeout=self.TIMEOUT, # for loading WSDL and xsd documents
operation_timeout=self.TIMEOUT, # for operations (POST/GET)
)
Expand All @@ -222,27 +238,40 @@ def request(self, operation_name, args, attempts_remaining=5):
# request, add it to an except block here, depending on what
# you want to have happen. The first block just retries and
# logs. The second block causes reauthentication.
except (RetryableTotalConnectError, requests.exceptions.RequestException) as err:
except (
RetryableTotalConnectError,
requests.exceptions.RequestException,
) as err:
if attempts_remaining <= 0:
raise
if isinstance(err, RetryableTotalConnectError):
msg = f"{self.username} {operation_name}{args} {err.args[0]} on response"
msg = (
f"{self.username} {operation_name}{args} {err.args[0]} on response"
)
else:
msg = f"{self.username} {operation_name}{args} {err} on request"
if is_first_request:
LOGGER.info(f"{msg}: {attempts_remaining} retries remaining")
else:
LOGGER.debug(f"{msg}: {attempts_remaining} retries remaining")
time.sleep(self.retry_delay)
except ZeepFault as err:
except (ZeepFault, requests.exceptions.HTTPError) as err:
if attempts_remaining <= 0:
raise ServiceUnavailable(f"Error connecting to Total Connect service: {err}") from err
LOGGER.debug(f"Error connecting to Total Connect service: {attempts_remaining} retries remaining")
raise ServiceUnavailable(
f"Error connecting to Total Connect service: {err}"
) from err
LOGGER.debug(
f"Error connecting to Total Connect service: {attempts_remaining} retries remaining"
)
time.sleep(self.retry_delay)
except InvalidSessionError as err:
if attempts_remaining <= 0:
raise ServiceUnavailable(f"Invalid Session after multiple retries: {err}") from err
LOGGER.info(f"reauthenticating {self.username}: {attempts_remaining} retries remaining")
raise ServiceUnavailable(
f"Invalid Session after multiple retries: {err}"
) from err
LOGGER.info(
f"reauthenticating {self.username}: {attempts_remaining} retries remaining"
)
old_token = self.token
self.token = None
self.authenticate()
Expand All @@ -264,9 +293,10 @@ def authenticate(self):
operation_name = (
"AuthenticateUserLogin" if self._locations else "LoginAndGetSessionDetails"
)
response = self.request(operation_name, (
self.username, self.password, self.API_APP_ID, self.API_APP_VERSION
))
response = self.request(
operation_name,
(self.username, self.password, self.API_APP_ID, self.API_APP_VERSION),
)
try:
self.raise_for_resultcode(response)
except AuthenticationError:
Expand All @@ -289,7 +319,9 @@ def authenticate(self):

def validate_usercode(self, device_id, usercode):
"""Return True if the usercode is valid for the device."""
response = self.request("ValidateUserCode", (self.token, device_id, str(usercode)))
response = self.request(
"ValidateUserCode", (self.token, device_id, str(usercode))
)
try:
self.raise_for_resultcode(response)
except UsercodeInvalid:
Expand Down Expand Up @@ -336,8 +368,9 @@ def _make_locations(self, response):

# set the usercode for the location
usercode = (
self.usercodes.get(location_id) or # noqa: W504
self.usercodes.get(str(location_id)) or self.usercodes.get("default")
self.usercodes.get(location_id) # noqa: W504
or self.usercodes.get(str(location_id))
or self.usercodes.get("default")
)
if usercode:
location.usercode = usercode
Expand Down

0 comments on commit 20373cb

Please sign in to comment.