Skip to content

Commit

Permalink
[source-hubspot] bump cdk, implement rfr for contacts_form_submission…
Browse files Browse the repository at this point in the history
…s, contact_list_memberships, contact_merged_audit, and mock server tests
  • Loading branch information
brianjlai committed May 7, 2024
1 parent b8838a4 commit 1d3d4cc
Show file tree
Hide file tree
Showing 16 changed files with 1,242 additions and 109 deletions.
353 changes: 318 additions & 35 deletions airbyte-integrations/connectors/source-hubspot/poetry.lock

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from airbyte_cdk.entrypoint import logger
from airbyte_cdk.models import FailureType, SyncMode
from airbyte_cdk.sources import Source
from airbyte_cdk.sources.streams import IncrementalMixin, Stream
from airbyte_cdk.sources.streams import CheckpointMixin, Stream
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.sources.streams.core import StreamData
from airbyte_cdk.sources.streams.http import HttpStream, HttpSubStream
Expand Down Expand Up @@ -838,7 +838,7 @@ def availability_strategy(self) -> Optional[AvailabilityStrategy]:
return HubspotAvailabilityStrategy()


class ClientSideIncrementalStream(Stream, IncrementalMixin):
class ClientSideIncrementalStream(Stream, CheckpointMixin):
_cursor_value = ""

@property
Expand All @@ -861,7 +861,10 @@ def state(self) -> Mapping[str, Any]:

@state.setter
def state(self, value: Mapping[str, Any]):
    """
    Restore the client-side cursor from an incoming state mapping.

    :param value: state mapping keyed by ``self.cursor_field``; may be empty or None
    An empty/None state resets the cursor to the empty string.
    """
    # The cursor lookup must only happen behind the emptiness guard: performing it
    # unconditionally first raises on an empty or None state before the guard can apply.
    self._cursor_value = value[self.cursor_field] if value else ""

def filter_by_state(self, stream_state: Mapping[str, Any] = None, record: Mapping[str, Any] = None) -> bool:
"""
Expand Down Expand Up @@ -1232,7 +1235,7 @@ def stream_slices(
self, *, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: Mapping[str, Any] = None
) -> Iterable[Optional[Mapping[str, Any]]]:
self.set_sync(sync_mode, stream_state)
return [None]
return [{}] # I changed this from [None] since this is a more accurate depiction of what is actually being done. Sync one slice

def set_sync(self, sync_mode: SyncMode, stream_state):
self._sync_mode = sync_mode
Expand Down Expand Up @@ -1361,7 +1364,7 @@ class ContactLists(IncrementalStream):
unnest_fields = ["metaData"]


class ContactsAllBase(Stream):
class ContactsAllBase(Stream, CheckpointMixin):
url = "/contacts/v1/lists/all/contacts/all"
updated_at_field = "timestamp"
more_key = "has-more"
Expand All @@ -1374,6 +1377,62 @@ class ContactsAllBase(Stream):
records_field = None
filter_field = None
filter_value = None
_state = {}
limit_field = "count"
limit = 100

@property
def state(self) -> MutableMapping[str, Any]:
    """Return the checkpoint state currently held in ``self._state``."""
    return self._state

@state.setter
def state(self, value: MutableMapping[str, Any]) -> None:
    """Replace the held checkpoint state with ``value``; the mapping is stored as-is, not copied."""
    self._state = value

def read_records(
    self,
    sync_mode: SyncMode,
    cursor_field: List[str] = None,
    stream_slice: Mapping[str, Any] = None,
    stream_state: Mapping[str, Any] = None,
) -> Iterable[Mapping[str, Any]]:
    """
    This is a specialized read_records for resumable full refresh that only attempts to read a single page of records
    at a time and updates the state w/ a synthetic cursor based on the Hubspot cursor pagination value `vidOffset`

    :param sync_mode: not read by this method
    :param cursor_field: not read by this method
    :param stream_slice: forwarded to the request and also used as the pagination token of the page to fetch
    :param stream_state: forwarded to the request and to response parsing
    :return: generator over the (optionally filtered) unnested records of the single page that was read
    :raises AirbyteTracedException: when the request fails with HTTP 401
    :raises requests.exceptions.HTTPError: any other HTTP error is re-raised unchanged
    """
    # The incoming slice doubles as the pagination token for the page to fetch.
    next_page_token = stream_slice
    try:
        properties = self._property_wrapper
        if properties and properties.too_many_properties:
            records, response = self._read_stream_records(
                stream_slice=stream_slice,
                stream_state=stream_state,
                next_page_token=next_page_token,
            )
        else:
            response = self.handle_request(
                stream_slice=stream_slice,
                stream_state=stream_state,
                next_page_token=next_page_token,
                properties=properties,
            )
            records = self._transform(self.parse_response(response, stream_state=stream_state, stream_slice=stream_slice))

        if self.filter_old_records:
            records = self._filter_old_records(records)
        # This statement already makes the function a generator, so an empty page simply yields nothing.
        yield from self.record_unnester.unnest(records)

        # Only checkpoint after the whole page was emitted; an empty state means there is no further page.
        self.state = self.next_page_token(response) or {}
    except requests.exceptions.HTTPError as e:
        if e.response.status_code == HTTPStatus.UNAUTHORIZED:
            # Chain the HTTP error explicitly so the original failure stays attached as the cause.
            raise AirbyteTracedException(
                "The authentication to HubSpot has expired. Re-authenticate to restore access to HubSpot."
            ) from e
        raise

def _transform(self, records: Iterable) -> Iterable:
for record in super()._transform(records):
Expand Down Expand Up @@ -1415,6 +1474,12 @@ class ContactsFormSubmissions(ContactsAllBase, ABC):
filter_value = "all"


class ContactsMergedAudit(ContactsAllBase, ABC):
    """Contact merge-audit stream built on the shared ``ContactsAllBase`` reader."""

    # NOTE(review): presumably the nested fields that get flattened by the record unnester - confirm against the base class.
    unnest_fields = ["merged_from_email", "merged_to_email"]
    # NOTE(review): presumably the per-contact key holding the records this stream emits - confirm against the base class.
    records_field = "merge-audits"


class Deals(CRMSearchStream):
"""Deals, API v3"""

Expand Down Expand Up @@ -2101,65 +2166,6 @@ class Contacts(CRMSearchStream):
scopes = {"crm.objects.contacts.read"}


class ContactsMergedAudit(Stream):
    """Merge-audit entries of contacts, requested in batches of contact ids from the contacts batch endpoint."""

    url = "/contacts/v1/contact/vids/batch/"
    updated_at_field = "timestamp"
    scopes = {"crm.objects.contacts.read"}
    unnest_fields = ["merged_from_email", "merged_to_email"]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Kept so stream_slices can build a Contacts stream from the same arguments.
        self.config = kwargs

    def stream_slices(
        self, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: Mapping[str, Any] = None, **kwargs
    ) -> Iterable[Mapping[str, Any]]:
        """
        Read every contact in full refresh mode and group the ids of contacts that have merged object ids.

        :return: list of ``{"vid": [...]}`` slices, each holding at most 100 contact ids
        """
        # we can query a max of 100 contacts at a time
        max_contacts = 100
        slices = []
        contact_batch = []

        contacts = Contacts(**self.config)
        contacts._sync_mode = SyncMode.full_refresh
        contacts.filter_old_records = False

        for contact in contacts.read_records(sync_mode=SyncMode.full_refresh):
            # Only contacts that were merged from other objects are queried for audits.
            if not contact.get("properties_hs_merged_object_ids"):
                continue
            contact_batch.append(contact["id"])

            if len(contact_batch) == max_contacts:
                slices.append({"vid": contact_batch})
                contact_batch = []

        # Flush the last, partially filled batch.
        if contact_batch:
            slices.append({"vid": contact_batch})

        return slices

    def request_params(
        self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None
    ) -> MutableMapping[str, Any]:
        """Request the contact ids carried by the current slice."""
        return {"vid": stream_slice["vid"]}

    def parse_response(
        self,
        response: requests.Response,
        *,
        stream_state: Mapping[str, Any],
        stream_slice: Mapping[str, Any] = None,
        next_page_token: Mapping[str, Any] = None,
    ) -> Iterable[Mapping]:
        """
        Yield the ``merge-audits`` entries of every contact in the parsed payload.

        A payload whose ``status`` is ``error`` is logged as a warning and yields nothing.
        """
        payload = self._parse_response(response)
        if payload.get("status", None) == "error":
            self.logger.warning(f"Stream `{self.name}` cannot be procced. {payload.get('message')}")
            return

        for contact in payload.values():
            yield from contact["merge-audits"]


class EngagementsCalls(CRMSearchStream):
entity = "calls"
last_modified_field = "hs_lastmodifieddate"
Expand Down Expand Up @@ -2273,7 +2279,7 @@ class EmailSubscriptions(Stream):
filter_old_records = False


class WebAnalyticsStream(IncrementalMixin, HttpSubStream, Stream):
class WebAnalyticsStream(CheckpointMixin, HttpSubStream, Stream):
"""
A base class for Web Analytics API
Docs: https://developers.hubspot.com/docs/api/events/web-analytics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .request_builders.streams import CRMStreamRequestBuilder, IncrementalCRMStreamRequestBuilder, WebAnalyticsRequestBuilder
from .response_builder.helpers import RootHttpResponseBuilder
from .response_builder.api import ScopesResponseBuilder
from .response_builder.pagination import HubspotCursorPaginationStrategy
from .response_builder.streams import GenericResponseBuilder, HubspotStreamResponseBuilder


Expand Down Expand Up @@ -131,3 +132,28 @@ def read_from_stream(
cls, cfg, stream: str, sync_mode: SyncMode, state: Optional[List[AirbyteStateMessage]] = None, expecting_exception: bool = False
) -> EntrypointOutput:
return read(SourceHubspot(), cfg, cls.catalog(stream, sync_mode), state, expecting_exception)


@freezegun.freeze_time("2024-05-05T00:00:00Z")
class HubspotContactsTestCase(HubspotTestCase):
    """Shared builders for tests whose mocked responses keep their records under ``contacts``."""

    CHECKPOINT_FIELD = "vid"

    @classmethod
    def response_builder(cls, stream_name) -> HubspotStreamResponseBuilder:
        """Response builder for ``stream_name`` using the cursor pagination strategy."""
        pagination = HubspotCursorPaginationStrategy()
        return HubspotStreamResponseBuilder.for_stream(stream_name, "contacts", pagination)

    @classmethod
    def response(cls, stream_name, with_pagination: bool = False) -> HubspotStreamResponseBuilder:
        """Response builder holding the same record twice, optionally flagged as having a next page."""
        cursor_path = FieldPath(cls.CHECKPOINT_FIELD)
        record = cls.record_builder(stream_name, cursor_path).with_field(FieldPath("id"), cls.OBJECT_ID)
        builder = cls.response_builder(stream_name=stream_name).with_record(record).with_record(record)
        return builder.with_pagination() if with_pagination else builder

    @classmethod
    def record_builder(cls, stream: str, record_cursor_path):
        """Record builder created from the stream's response template, without a record id path."""
        template = find_template(stream, __file__)
        return create_record_builder(
            template,
            records_path=FieldPath("contacts"),
            record_id_path=None,
            record_cursor_path=record_cursor_path,
        )
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,37 @@ def with_page_token(self, next_page_token: Dict):
def build(self):
    """Join the non-empty query parameters with ``&`` and return the request for ``self.URL``."""
    query = "&".join(param for param in self._query_params if param)
    return HttpRequest(self.URL, query_params=query)


# Mocking the Contacts endpoint alone is enough: it services the data extracted by
# ListMemberships, FormSubmissions and MergedAudit.
class ContactsStreamRequestBuilder(AbstractRequestBuilder):
    """Builds requests against the "all contacts" endpoint with a fixed page size of 100."""

    URL = "https://api.hubapi.com/contacts/v1/lists/all/contacts/all"

    def __init__(self):
        self._vid_offset = None
        self._filters = []

    @property
    def _count(self):
        """Page-size query parameter."""
        return "count=100"

    def with_filter(self, filter_field: str, filter_value: Any):
        """Append a ``field=value`` query parameter; returns the builder for chaining."""
        self._filters.append(f"{filter_field}={filter_value}")
        return self

    def with_vid_offset(self, vid_offset: str):
        """Set the ``vidOffset`` query parameter; returns the builder for chaining."""
        self._vid_offset = f"vidOffset={vid_offset}"
        return self

    @property
    def _query_params(self):
        """Iterator over the set parameters: count, then vidOffset (if any), then the filters in insertion order."""
        ordered = [self._count, self._vid_offset, *self._filters]
        return filter(None, ordered)

    def build(self):
        """Join the non-empty query parameters with ``&`` and return the request for ``URL``."""
        query = "&".join(param for param in self._query_params if param)
        return HttpRequest(self.URL, query_params=query)
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,9 @@ def update(self, response: Dict[str, Any]) -> None:
"link": None
}
}


class HubspotCursorPaginationStrategy(PaginationStrategy):
    """Flags a mocked response as having another page via ``has-more`` and a fixed ``vid-offset``."""

    def update(self, response: Dict[str, Any]) -> None:
        """Set both pagination keys on ``response`` in place."""
        response.update({"has-more": True, "vid-offset": "5331889818"})
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import json

from airbyte_cdk.test.mock_http import HttpResponse
from airbyte_cdk.test.mock_http.response_builder import FieldPath, HttpResponseBuilder, find_template
from airbyte_cdk.test.mock_http.response_builder import FieldPath, HttpResponseBuilder, PaginationStrategy, find_template

from . import AbstractResponseBuilder
from .pagination import HubspotPaginationStrategy


class HubspotStreamResponseBuilder(HttpResponseBuilder):
Expand All @@ -15,8 +14,8 @@ def pagination_strategy(self):
return self._pagination_strategy

@classmethod
def for_stream(cls, stream: str, records_path: str, pagination_strategy: PaginationStrategy):
    """
    Create a response builder from the template of ``stream``.

    :param stream: stream name used to look up the response template
    :param records_path: field of the template under which the records live
    :param pagination_strategy: strategy applied when the response is built with pagination
    """
    # A single definition keeps the @classmethod decorator on the three-argument signature;
    # leaving the superseded one-argument definition in front of it would bind the decorator
    # to the wrong function and leave this one undecorated.
    return cls(find_template(stream, __file__), FieldPath(records_path), pagination_strategy)


class GenericResponseBuilder(AbstractResponseBuilder):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

import freezegun
from airbyte_cdk.test.mock_http import HttpMocker
from airbyte_protocol.models import AirbyteStateBlob, AirbyteStateMessage, AirbyteStateType, AirbyteStreamState, StreamDescriptor, SyncMode

from . import HubspotContactsTestCase
from .request_builders.streams import ContactsStreamRequestBuilder


@freezegun.freeze_time("2024-05-04T00:00:00Z")
class TestContactsFormSubmissionsStream(HubspotContactsTestCase):
    """Mock-server tests for reading ``contacts_form_submissions`` page by page with checkpointed state."""

    SCOPES = ["crm.objects.contacts.read"]
    STREAM_NAME = "contacts_form_submissions"

    def _mock_auth_and_metadata(self, http_mocker: HttpMocker) -> None:
        """Register the oauth, scopes and custom-objects mocks."""
        self.mock_oauth(http_mocker, self.ACCESS_TOKEN)
        self.mock_scopes(http_mocker, self.ACCESS_TOKEN, self.SCOPES)
        self.mock_custom_objects(http_mocker)

    def _mock_contact_pages(self, http_mocker: HttpMocker) -> None:
        """Register a first page that has a next page, followed by a final page requested with the vidOffset."""
        first_page_request = ContactsStreamRequestBuilder().with_filter("formSubmissionMode", "all").build()
        first_page_response = self.response(stream_name=self.STREAM_NAME, with_pagination=True).build()
        self.mock_response(http_mocker, first_page_request, first_page_response)

        last_page_request = (
            ContactsStreamRequestBuilder().with_filter("formSubmissionMode", "all").with_vid_offset("5331889818").build()
        )
        last_page_response = self.response(stream_name=self.STREAM_NAME).build()
        self.mock_response(http_mocker, last_page_request, last_page_response)

    @HttpMocker()
    def test_read_multiple_contact_pages(self, http_mocker: HttpMocker):
        self._mock_auth_and_metadata(http_mocker)
        self._mock_contact_pages(http_mocker)

        output = self.read_from_stream(self.oauth_config(), self.STREAM_NAME, SyncMode.full_refresh)

        assert len(output.records) == 16
        # One state message per page: the first carries the next offset, the last one is empty.
        expected_states = [{"vidOffset": "5331889818"}, {}]
        for index, expected_state in enumerate(expected_states):
            state = output.state_messages[index].state
            assert state.stream.stream_state.dict() == expected_state
            assert state.stream.stream_descriptor.name == self.STREAM_NAME
            assert state.sourceStats.recordCount == 8

    @HttpMocker()
    def test_read_from_incoming_state(self, http_mocker: HttpMocker):
        incoming_stream_state = AirbyteStreamState(
            stream_descriptor=StreamDescriptor(name=self.STREAM_NAME),
            stream_state=AirbyteStateBlob(**{"vidOffset": "5331889818"}),
        )
        state = [AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=incoming_stream_state)]

        self._mock_auth_and_metadata(http_mocker)
        # Even though we only care about the request with a vidOffset parameter, the first page is
        # mocked as well in order to pass the availability check
        self._mock_contact_pages(http_mocker)

        output = self.read_from_stream(cfg=self.oauth_config(), stream=self.STREAM_NAME, sync_mode=SyncMode.full_refresh, state=state)

        assert len(output.records) == 8
        final_state = output.state_messages[0].state
        assert final_state.stream.stream_state.dict() == {}
        assert final_state.stream.stream_descriptor.name == self.STREAM_NAME
        assert final_state.sourceStats.recordCount == 8
Loading

0 comments on commit 1d3d4cc

Please sign in to comment.