Added start date and glob filtering
lazebnyi committed Oct 12, 2023
1 parent 868f662 commit b1aae06
Showing 4 changed files with 92 additions and 44 deletions.
airbyte-integrations/connectors/source-gcs/integration_tests/configured_catalog.json
@@ -8,15 +8,6 @@
},
"sync_mode": "full_refresh",
"destination_sync_mode": "overwrite"
},
{
"stream": {
"name": "example_2",
"json_schema": {},
"supported_sync_modes": ["full_refresh"]
},
"sync_mode": "full_refresh",
"destination_sync_mode": "overwrite"
}
]
}
8 changes: 2 additions & 6 deletions airbyte-integrations/connectors/source-gcs/setup.py
@@ -7,15 +7,11 @@

MAIN_REQUIREMENTS = [
"airbyte-cdk>=0.51.17",
"google-cloud-storage==2.5.0",
"google-cloud-storage==2.12.0",
"pandas==1.5.3",
"pyarrow==12.0.1",
"smart-open[s3]==5.1.0",
"wcmatch==8.4",
"dill==0.3.4",
"pytz",
"fastavro==1.4.11",
"python-snappy==0.6.1",
"smart-open[s3]==5.1.0"
]

TEST_REQUIREMENTS = [
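The google-cloud-storage bump from 2.5.0 to 2.12.0 is presumably what makes the match_glob argument used in the stream reader below available (python-storage added it in 2.10.0, as far as I can tell). A minimal sketch of that API, with a hypothetical bucket name and pattern that are not part of this commit:

from google.cloud import storage

client = storage.Client()  # assumes ambient GCP credentials
bucket = client.get_bucket("my-example-bucket")  # hypothetical bucket
# Server-side glob filtering: only CSVs under data/ are listed.
for blob in bucket.list_blobs(prefix="data/", match_glob="data/**/*.csv"):
    print(blob.name, blob.updated)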
63 changes: 53 additions & 10 deletions airbyte-integrations/connectors/source-gcs/source_gcs/config.py
@@ -2,32 +2,75 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from typing import List, Optional

from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
from pydantic import AnyUrl, Field


class SourceGCSStreamConfig(FileBasedStreamConfig):
name: str = Field(title="Name", description="The name of the stream.", order=0)
globs: Optional[List[str]] = Field(
title="Globs",
description="The pattern used to specify which files should be selected from the file system. For more information on glob "
'pattern matching look <a href="https://en.wikipedia.org/wiki/Glob_(programming)">here</a>.',
order=1,
)
format: CsvFormat = Field(
title="Format",
description="The configuration options that are used to alter how to read incoming files that deviate from "
"the standard formatting.",
order=2,
)
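
For context, a stream entry under this model might look like the sketch below. Values are illustrative, and required fields inherited from FileBasedStreamConfig (such as the validation policy) are omitted for brevity.

# Illustrative only - not taken from the commit.
example_stream = {
    "name": "invoices",                  # stream name (order=0)
    "globs": ["reports/2023/**/*.csv"],  # optional glob filter (order=1)
    "format": {"filetype": "csv"},       # CsvFormat options (order=2)
}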


class Config(AbstractFileBasedSpec):
"""
NOTE: When this Spec is changed, legacy_config_transformer.py must also be
modified to uptake the changes because it is responsible for converting
legacy GCS configs into file based configs using the File-Based CDK.
"""

@classmethod
def documentation_url(cls) -> AnyUrl:
"""
Returns the documentation URL.
"""
return AnyUrl("https://docs.airbyte.com/integrations/sources/gcs", scheme="https")

bucket: str = Field(title="Bucket", description="Name of the GCS bucket where the file(s) exist.", order=0)

service_account: str = Field(
title="Service Account Information.",
title="Service Account Information",
airbyte_secret=True,
description=(
"Enter your Google Cloud "
'<a href="https://cloud.google.com/iam/docs/creating-managing-service-account-keys#creating_service_account_keys">'
"service account key</a> in JSON format"
),
order=0,
)

bucket: str = Field(title="Bucket", description="Name of the GCS bucket where the file(s) exist.", order=1)

streams: List[SourceGCSStreamConfig] = Field(
title="The list of streams to sync",
description=(
"Each instance of this configuration defines a <a href=https://docs.airbyte.com/cloud/core-concepts#stream>stream</a>. "
"Use this to define which files belong in the stream, their format, and how they should be "
"parsed and validated. When sending data to warehouse destination such as Snowflake or "
"BigQuery, each stream is a separate table."
),
order=2,
)

@classmethod
def documentation_url(cls) -> AnyUrl:
"""
Returns the documentation URL.
"""
return AnyUrl("https://docs.airbyte.com/integrations/sources/gcs", scheme="https")

@staticmethod
def replace_enum_allOf_and_anyOf(schema):
"""
Replace allOf with anyOf when appropriate in the schema with one value.
"""
objects_to_check = schema["properties"]["streams"]["items"]["properties"]["format"]
if len(objects_to_check.get("allOf", [])) == 1:
objects_to_check["anyOf"] = objects_to_check.pop("allOf")

return super(Config, Config).replace_enum_allOf_and_anyOf(schema)
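
To make the allOf-to-anyOf rewrite above concrete, here is a toy schema run through the same logic; the schema content is made up, and only the nesting mirrors the code:

# Standalone sketch of the transformation in replace_enum_allOf_and_anyOf.
schema = {
    "properties": {
        "streams": {
            "items": {"properties": {"format": {"allOf": [{"title": "CSV Format"}]}}}
        }
    }
}
fmt = schema["properties"]["streams"]["items"]["properties"]["format"]
if len(fmt.get("allOf", [])) == 1:
    fmt["anyOf"] = fmt.pop("allOf")
assert fmt == {"anyOf": [{"title": "CSV Format"}]}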
airbyte-integrations/connectors/source-gcs/source_gcs/stream_reader.py
@@ -2,10 +2,11 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import itertools
import json
import logging
from contextlib import contextmanager
from datetime import timedelta
from datetime import datetime, timedelta
from io import IOBase
from typing import Iterable, List, Optional

@@ -33,6 +34,7 @@ class SourceGCSStreamReader(AbstractFileBasedStreamReader):
def __init__(self):
super().__init__()
self._gcs_client = None
self._config = None

@property
def config(self) -> Config:
@@ -43,38 +45,54 @@ def config(self, value: Config):
assert isinstance(value, Config), "Config must be an instance of the expected Config class."
self._config = value

@property
def gcs_client(self) -> storage.Client:
def _initialize_gcs_client(self):
if self.config is None:
raise ValueError("Source config is missing; cannot create the GCS client.")
if self._gcs_client is None:
credentials = service_account.Credentials.from_service_account_info(json.loads(self.config.service_account))
credentials = self._get_credentials()
self._gcs_client = storage.Client(credentials=credentials)
return self._gcs_client

def _get_credentials(self):
return service_account.Credentials.from_service_account_info(json.loads(self.config.service_account))

@property
def gcs_client(self) -> storage.Client:
return self._initialize_gcs_client()

def get_matching_files(self, globs: List[str], prefix: Optional[str], logger: logging.Logger) -> Iterable[RemoteFile]:
"""
Retrieve all files matching the specified glob patterns in GCS.
"""
try:
bucket = self.gcs_client.get_bucket(self.config.bucket)
remote_files = bucket.list_blobs(prefix=prefix)
start_date = (
datetime.strptime(self.config.start_date, self.DATE_TIME_FORMAT) if self.config and self.config.start_date else None
)
prefixes = [prefix] if prefix else self.get_prefixes_from_globs(globs or [])
globs = globs or [None]

for prefix, glob in itertools.product(prefixes, globs):
bucket = self.gcs_client.get_bucket(self.config.bucket)
blobs = bucket.list_blobs(prefix=prefix, match_glob=glob)
for blob in blobs:
last_modified = blob.updated.astimezone(pytz.utc).replace(tzinfo=None)

for remote_file in remote_files:
if FILE_FORMAT in remote_file.name.lower():
yield RemoteFile(
uri=remote_file.generate_signed_url(expiration=timedelta(hours=1), version="v4"),
last_modified=remote_file.updated.astimezone(pytz.utc).replace(tzinfo=None),
)
if FILE_FORMAT in blob.name.lower() and (not start_date or last_modified >= start_date):
uri = blob.generate_signed_url(expiration=timedelta(hours=1), version="v4")

yield RemoteFile(uri=uri, last_modified=last_modified)

except Exception as exc:
logger.error(f"Error while listing files: {str(exc)}")
raise ErrorListingFiles(
FileBasedSourceError.ERROR_LISTING_FILES,
source="gcs",
bucket=self.config.bucket,
prefix=prefix,
) from exc
self._handle_file_listing_error(exc, prefix, logger)

def _handle_file_listing_error(self, exc: Exception, prefix: str, logger: logging.Logger):
logger.error(f"Error while listing files: {str(exc)}")
raise ErrorListingFiles(
FileBasedSourceError.ERROR_LISTING_FILES,
source="gcs",
bucket=self.config.bucket,
prefix=prefix,
) from exc
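
Two notes on the new listing logic: config.start_date and DATE_TIME_FORMAT are referenced here but defined in hunks not shown on this page, and every derived prefix is paired with every glob, so one list_blobs call is issued per (prefix, glob) pair. A small sketch of that behaviour, with hypothetical values:

import itertools
from datetime import datetime

prefixes = ["reports/"]            # as derived by get_prefixes_from_globs
globs = ["reports/**/*.csv"]
start_date = datetime(2023, 1, 1)  # parsed from config.start_date

for prefix, glob in itertools.product(prefixes, globs):
    # In the reader, each pair becomes list_blobs(prefix=prefix, match_glob=glob);
    # a blob is kept only if it was updated on or after start_date.
    fake_blobs = [("reports/a.csv", datetime(2023, 6, 1))]
    for name, updated in fake_blobs:
        if not start_date or updated >= start_date:
            print(name)  # would be yielded as a RemoteFile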

@contextmanager
def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase:
