From 44b745fbd825bdd70089c4264da8be71c6369a6b Mon Sep 17 00:00:00 2001 From: Chris Sellers Date: Tue, 13 Aug 2024 14:18:04 +1000 Subject: [PATCH] MOD: Stream compressed for corporate actions API --- databento/reference/api/corporate.py | 9 ++++++++- tests/test_reference_corporate.py | 12 +++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/databento/reference/api/corporate.py b/databento/reference/api/corporate.py index 177ebce..dbacd40 100644 --- a/databento/reference/api/corporate.py +++ b/databento/reference/api/corporate.py @@ -2,9 +2,12 @@ from collections.abc import Iterable from datetime import date +from io import BytesIO from io import StringIO import pandas as pd +import zstandard +from databento_dbn import Compression from databento_dbn import SType from databento.common import API_VERSION @@ -112,6 +115,7 @@ def get_range( "events": ",".join(events) if events else None, "countries": ",".join(countries) if countries else None, "security_types": ",".join(security_types) if security_types else None, + "compression": str(Compression.ZSTD), # Always request zstd } response = self._post( @@ -120,7 +124,10 @@ def get_range( basic_auth=True, ) - df = pd.read_json(StringIO(response.text), lines=True) + decompressor = zstandard.ZstdDecompressor() + decompressed_content = decompressor.stream_reader(BytesIO(response.content)).read() + + df = pd.read_json(StringIO(decompressed_content.decode()), lines=True) if df.empty: return df diff --git a/tests/test_reference_corporate.py b/tests/test_reference_corporate.py index eac21ec..4461b29 100644 --- a/tests/test_reference_corporate.py +++ b/tests/test_reference_corporate.py @@ -8,6 +8,7 @@ import pandas as pd import pytest import requests +import zstandard from databento.reference.client import Reference from tests import TESTS_ROOT @@ -77,7 +78,7 @@ def test_corporate_actions_get_range_sends_expected_request( ) -> None: # Arrange mock_response = MagicMock() - mock_response.text = "{}" + mock_response.content = zstandard.compress(b"{}") mock_response.__enter__.return_value = mock_response mock_response.__exit__ = MagicMock() monkeypatch.setattr(requests, "post", mock_post := MagicMock(return_value=mock_response)) @@ -110,6 +111,7 @@ def test_corporate_actions_get_range_sends_expected_request( "events": expected_events, "countries": expected_countries, "security_types": expected_security_types, + "compression": "zstd", } assert call["timeout"] == (100, 100) assert isinstance(call["auth"], requests.auth.HTTPBasicAuth) @@ -122,7 +124,7 @@ def test_corporate_actions_get_range_response_parsing_as_pit( # Arrange data_path = Path(TESTS_ROOT) / "data" / "REFERENCE" / "test_data.corporate-actions.ndjson" mock_response = MagicMock() - mock_response.text = data_path.read_text() + mock_response.content = zstandard.compress(data_path.read_bytes()) mock_response.__enter__.return_value = mock_response mock_response.__exit__ = MagicMock() monkeypatch.setattr(requests, "post", MagicMock(return_value=mock_response)) @@ -152,7 +154,7 @@ def test_corporate_actions_get_range_response( # Arrange data_path = Path(TESTS_ROOT) / "data" / "REFERENCE" / "test_data.corporate-actions-pit.ndjson" mock_response = MagicMock() - mock_response.text = data_path.read_text() + mock_response.content = zstandard.compress(data_path.read_bytes()) mock_response.__enter__.return_value = mock_response mock_response.__exit__ = MagicMock() monkeypatch.setattr(requests, "post", MagicMock(return_value=mock_response)) @@ -178,7 +180,7 @@ def test_corporate_actions_get_range_with_ts_record_index( # Arrange data_path = Path(TESTS_ROOT) / "data" / "REFERENCE" / "test_data.corporate-actions.ndjson" mock_response = MagicMock() - mock_response.text = data_path.read_text() + mock_response.content = zstandard.compress(data_path.read_bytes()) mock_response.__enter__.return_value = mock_response mock_response.__exit__ = MagicMock() monkeypatch.setattr(requests, "post", MagicMock(return_value=mock_response)) @@ -212,7 +214,7 @@ def test_corporate_actions_get_range_without_flattening( # Arrange data_path = Path(TESTS_ROOT) / "data" / "REFERENCE" / "test_data.corporate-actions.ndjson" mock_response = MagicMock() - mock_response.text = data_path.read_text() + mock_response.content = zstandard.compress(data_path.read_bytes()) mock_response.__enter__.return_value = mock_response mock_response.__exit__ = MagicMock() monkeypatch.setattr(requests, "post", MagicMock(return_value=mock_response))