Skip to content

Commit

Permalink
CDK: Fix request_cache clearing and move it to tmp folder (#30719)
Browse files Browse the repository at this point in the history
Co-authored-by: Eugene Kulak <[email protected]>
  • Loading branch information
keu and eugene-kulak authored Sep 28, 2023
1 parent 2bc7f34 commit 5eba3c3
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 22 deletions.
2 changes: 0 additions & 2 deletions airbyte-cdk/python/.gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
.coverage

# TODO: these are tmp files generated by unit tests. They should go to the /tmp directory.
cache_http_stream*.yml
<MagicMock*
cache*.sqlite
1 change: 1 addition & 0 deletions airbyte-cdk/python/airbyte_cdk/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def run(self, parsed_args: argparse.Namespace) -> Iterable[str]:
source_spec: ConnectorSpecification = self.source.spec(self.logger)
try:
with tempfile.TemporaryDirectory() as temp_dir:
os.environ["TMPDIR"] = temp_dir # set this as default temp directory (used by requests_cache to store *.sqlite files)
if cmd == "spec":
message = AirbyteMessage(type=Type.SPEC, spec=source_spec)
yield from [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2021 Airbyte, Inc., all rights reserved.
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

# Initialize Streams Package
Expand Down
15 changes: 5 additions & 10 deletions airbyte-cdk/python/airbyte_cdk/sources/streams/http/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@


import logging
import os
import urllib
from abc import ABC, abstractmethod
from contextlib import suppress
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union
from urllib.parse import urljoin

Expand Down Expand Up @@ -64,18 +62,15 @@ def use_cache(self) -> bool:
return False

def request_cache(self) -> requests.Session:
self.clear_cache()
return requests_cache.CachedSession(self.cache_filename)
backend = requests_cache.SQLiteCache(use_temp=True)
return requests_cache.CachedSession(self.cache_filename, backend=backend)

def clear_cache(self) -> None:
"""
remove cache file only once
clear cached requests for current session, can be called any time
"""
STREAM_CACHE_FILES = globals().setdefault("STREAM_CACHE_FILES", set())
if self.cache_filename not in STREAM_CACHE_FILES:
with suppress(FileNotFoundError):
os.remove(self.cache_filename)
STREAM_CACHE_FILES.add(self.cache_filename)
if isinstance(self._session, requests_cache.CachedSession):
self._session.cache.clear()

@property
@abstractmethod
Expand Down
26 changes: 17 additions & 9 deletions airbyte-cdk/python/unit_tests/sources/streams/http/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,13 +395,20 @@ def test_parent_attribute_exist():
assert child_stream.parent == parent_stream


def test_cache_response(mocker):
def test_that_response_was_cached(mocker, requests_mock):
requests_mock.register_uri("GET", "https://google.com/", text="text")
stream = CacheHttpStream()
stream.clear_cache()
mocker.patch.object(stream, "url_base", "https://google.com/")
list(stream.read_records(sync_mode=SyncMode.full_refresh))
records = list(stream.read_records(sync_mode=SyncMode.full_refresh))

assert requests_mock.called

requests_mock.reset_mock()
new_records = list(stream.read_records(sync_mode=SyncMode.full_refresh))

with open(stream.cache_filename, "rb") as f:
assert f.read()
assert len(records) == len(new_records)
assert not requests_mock.called


class CacheHttpStreamWithSlices(CacheHttpStream):
Expand All @@ -425,25 +432,26 @@ def test_using_cache(mocker, requests_mock):

parent_stream = CacheHttpStreamWithSlices()
mocker.patch.object(parent_stream, "url_base", "https://google.com/")
parent_stream.clear_cache()

assert requests_mock.call_count == 0
assert parent_stream._session.cache.response_count() == 0
assert len(parent_stream._session.cache.responses) == 0

for _slice in parent_stream.stream_slices():
list(parent_stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=_slice))

assert requests_mock.call_count == 2
assert parent_stream._session.cache.response_count() == 2
assert len(parent_stream._session.cache.responses) == 2

child_stream = CacheHttpSubStream(parent=parent_stream)

for _slice in child_stream.stream_slices(sync_mode=SyncMode.full_refresh):
pass

assert requests_mock.call_count == 2
assert parent_stream._session.cache.response_count() == 2
assert parent_stream._session.cache.has_url("https://google.com/")
assert parent_stream._session.cache.has_url("https://google.com/search")
assert len(parent_stream._session.cache.responses) == 2
assert parent_stream._session.cache.contains(url="https://google.com/")
assert parent_stream._session.cache.contains(url="https://google.com/search")


class AutoFailTrueHttpStream(StubBasicReadHttpStream):
Expand Down

0 comments on commit 5eba3c3

Please sign in to comment.