Skip to content

Commit

Permalink
[ISSUE #30353] remove file_type from stream config (#30453)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxi297 authored Sep 18, 2023
1 parent 2607a5f commit b6836ad
Show file tree
Hide file tree
Showing 22 changed files with 630 additions and 709 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _check_list_files(self, stream: "AbstractFileBasedStream") -> List[RemoteFil
return files

def _check_parse_record(self, stream: "AbstractFileBasedStream", file: RemoteFile, logger: logging.Logger) -> None:
parser = stream.get_parser(stream.config.file_type)
parser = stream.get_parser()

try:
record = next(iter(parser.parse_records(stream.config, file, self.stream_reader, logger, discovered_schema=None)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#

from enum import Enum
from typing import Any, List, Mapping, Optional, Type, Union
from typing import Any, List, Mapping, Optional, Union

from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat
Expand All @@ -16,9 +16,6 @@
PrimaryKeyType = Optional[Union[str, List[str]]]


VALID_FILE_TYPES: Mapping[str, Type[BaseModel]] = {"avro": AvroFormat, "csv": CsvFormat, "jsonl": JsonlFormat, "parquet": ParquetFormat}


class ValidationPolicy(Enum):
emit_record = "Emit Record"
skip_record = "Skip Record"
Expand All @@ -27,7 +24,6 @@ class ValidationPolicy(Enum):

class FileBasedStreamConfig(BaseModel):
name: str = Field(title="Name", description="The name of the stream.")
file_type: str = Field(title="File Type", description="The data file type that is being extracted for a stream.")
globs: Optional[List[str]] = Field(
title="Globs",
description='The pattern used to specify which files should be selected from the file system. For more information on glob pattern matching look <a href="https://en.wikipedia.org/wiki/Glob_(programming)">here</a>.',
Expand All @@ -54,7 +50,7 @@ class FileBasedStreamConfig(BaseModel):
description="When the state history of the file store is full, syncs will only read files that were last modified in the provided day range.",
default=3,
)
format: Optional[Union[AvroFormat, CsvFormat, JsonlFormat, ParquetFormat]] = Field(
format: Union[AvroFormat, CsvFormat, JsonlFormat, ParquetFormat] = Field(
title="Format",
description="The configuration options that are used to alter how to read incoming files that deviate from the standard formatting.",
)
Expand All @@ -64,37 +60,6 @@ class FileBasedStreamConfig(BaseModel):
default=False,
)

@validator("file_type", pre=True)
def validate_file_type(cls, v: str) -> str:
if v not in VALID_FILE_TYPES:
raise ValueError(f"Format filetype {v} is not a supported file type")
return v

@classmethod
def _transform_legacy_config(cls, legacy_config: Mapping[str, Any], file_type: str) -> Mapping[str, Any]:
if file_type.casefold() not in VALID_FILE_TYPES:
raise ValueError(f"Format filetype {file_type} is not a supported file type")
if file_type.casefold() == "parquet" or file_type.casefold() == "avro":
legacy_config = cls._transform_legacy_parquet_or_avro_config(legacy_config)
return {file_type: VALID_FILE_TYPES[file_type.casefold()].parse_obj({key: val for key, val in legacy_config.items()})}

@classmethod
def _transform_legacy_parquet_or_avro_config(cls, config: Mapping[str, Any]) -> Mapping[str, Any]:
"""
The legacy parquet parser converts decimal fields to numbers. This isn't desirable because it can lead to precision loss.
To avoid introducing a breaking change with the new default, we will set decimal_as_float to True in the legacy configs.
"""
filetype = config.get("filetype")
if filetype != "parquet" and filetype != "avro":
raise ValueError(
f"Expected {filetype} format, got {config}. This is probably due to a CDK bug. Please reach out to the Airbyte team for support."
)
if config.get("decimal_as_float"):
raise ValueError(
f"Received legacy {filetype} file form with 'decimal_as_float' set. This is unexpected. Please reach out to the Airbyte team for support."
)
return {**config, **{"decimal_as_float": True}}

@validator("input_schema", pre=True)
def validate_input_schema(cls, v: Optional[str]) -> Optional[str]:
if v:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
catalog_path: Optional[str] = None,
availability_strategy: Optional[AbstractFileBasedAvailabilityStrategy] = None,
discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy(),
parsers: Mapping[str, FileTypeParser] = default_parsers,
parsers: Mapping[Type[Any], FileTypeParser] = default_parsers,
validation_policies: Mapping[ValidationPolicy, AbstractSchemaValidationPolicy] = DEFAULT_SCHEMA_VALIDATION_POLICIES,
cursor_cls: Type[AbstractFileBasedCursor] = DefaultFileBasedCursor,
):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from typing import Mapping
from typing import Any, Mapping, Type

from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat
from airbyte_cdk.sources.file_based.config.jsonl_format import JsonlFormat
from airbyte_cdk.sources.file_based.config.parquet_format import ParquetFormat

from .avro_parser import AvroParser
from .csv_parser import CsvParser
from .file_type_parser import FileTypeParser
from .jsonl_parser import JsonlParser
from .parquet_parser import ParquetParser

default_parsers: Mapping[str, FileTypeParser] = {
"avro": AvroParser(),
"csv": CsvParser(),
"jsonl": JsonlParser(),
"parquet": ParquetParser(),
default_parsers: Mapping[Type[Any], FileTypeParser] = {
AvroFormat: AvroParser(),
CsvFormat: CsvParser(),
JsonlFormat: JsonlParser(),
ParquetFormat: ParquetParser(),
}

__all__ = ["AvroParser", "CsvParser", "JsonlParser", "ParquetParser", "default_parsers"]
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def infer_schema(
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
) -> SchemaType:
avro_format = config.format or AvroFormat()
avro_format = config.format
if not isinstance(avro_format, AvroFormat):
raise ValueError(f"Expected ParquetFormat, got {avro_format}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def _no_cast(row: Mapping[str, str]) -> Mapping[str, str]:


def _extract_format(config: FileBasedStreamConfig) -> CsvFormat:
config_format = config.format or CsvFormat()
config_format = config.format
if not isinstance(config_format, CsvFormat):
raise ValueError(f"Invalid format config: {config_format}")
return config_format
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def infer_schema(
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
) -> SchemaType:
parquet_format = config.format or ParquetFormat()
parquet_format = config.format
if not isinstance(parquet_format, ParquetFormat):
raise ValueError(f"Expected ParquetFormat, got {parquet_format}")

Expand All @@ -54,7 +54,7 @@ def parse_records(
logger: logging.Logger,
discovered_schema: Optional[Mapping[str, SchemaType]],
) -> Iterable[Dict[str, Any]]:
parquet_format = config.format or ParquetFormat()
parquet_format = config.format
if not isinstance(parquet_format, ParquetFormat):
logger.info(f"Expected ParquetFormat, got {parquet_format}")
raise ConfigValidationError(FileBasedSourceError.CONFIG_VALIDATION_ERROR)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from abc import abstractmethod
from functools import cached_property, lru_cache
from typing import Any, Dict, Iterable, List, Mapping, Optional
from typing import Any, Dict, Iterable, List, Mapping, Optional, Type

from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.file_based.availability_strategy import AbstractFileBasedAvailabilityStrategy
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(
stream_reader: AbstractFileBasedStreamReader,
availability_strategy: AbstractFileBasedAvailabilityStrategy,
discovery_policy: AbstractDiscoveryPolicy,
parsers: Dict[str, FileTypeParser],
parsers: Dict[Type[Any], FileTypeParser],
validation_policy: AbstractSchemaValidationPolicy,
):
super().__init__()
Expand Down Expand Up @@ -121,11 +121,11 @@ def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
"""
...

def get_parser(self, file_type: str) -> FileTypeParser:
def get_parser(self) -> FileTypeParser:
try:
return self._parsers[file_type]
return self._parsers[type(self.config.format)]
except KeyError:
raise UndefinedParserError(FileBasedSourceError.UNDEFINED_PARSER, stream=self.name, file_type=file_type)
raise UndefinedParserError(FileBasedSourceError.UNDEFINED_PARSER, stream=self.name, format=type(self.config.format))

def record_passes_validation_policy(self, record: Mapping[str, Any]) -> bool:
if self.validation_policy:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import asyncio
import itertools
import traceback
from copy import deepcopy
from functools import cache
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Set, Union

Expand Down Expand Up @@ -79,7 +80,7 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Airbyte
# On read requests we should always have the catalog available
raise MissingSchemaError(FileBasedSourceError.MISSING_SCHEMA, stream=self.name)
# The stream only supports a single file type, so we can use the same parser for all files
parser = self.get_parser(self.config.file_type)
parser = self.get_parser()
for file in stream_slice["files"]:
# only serialize the datetime once
file_datetime_string = file.last_modified.strftime(self.DATE_TIME_FORMAT)
Expand Down Expand Up @@ -190,7 +191,7 @@ def _get_raw_json_schema(self) -> JsonSchema:
if not inferred_schema:
raise InvalidSchemaError(
FileBasedSourceError.INVALID_SCHEMA_ERROR,
details=f"Empty schema. Please check that the files are valid {self.config.file_type}",
details=f"Empty schema. Please check that the files are valid for format {self.config.format}",
stream=self.name,
)

Expand All @@ -210,7 +211,8 @@ def list_files(self) -> List[RemoteFile]:
def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
loop = asyncio.get_event_loop()
schema = loop.run_until_complete(self._infer_schema(files))
return self._fill_nulls(schema)
# as infer schema returns a Mapping that is assumed to be immutable, we need to create a deepcopy to avoid modifying the reference
return self._fill_nulls(deepcopy(schema))

@staticmethod
def _fill_nulls(schema: Mapping[str, Any]) -> Mapping[str, Any]:
Expand Down Expand Up @@ -258,11 +260,11 @@ async def _infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:

async def _infer_file_schema(self, file: RemoteFile) -> SchemaType:
try:
return await self.get_parser(self.config.file_type).infer_schema(self.config, file, self._stream_reader, self.logger)
return await self.get_parser().infer_schema(self.config, file, self._stream_reader, self.logger)
except Exception as exc:
raise SchemaInferenceError(
FileBasedSourceError.SCHEMA_INFERENCE_ERROR,
file=file.uri,
stream_file_type=self.config.file_type,
format=str(self.config.format),
stream=self.name,
) from exc
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,17 @@
id="test_decimal_missing_precision"),
pytest.param(_default_avro_format, {"type": "bytes", "logicalType": "decimal", "precision": 9}, None, ValueError,
id="test_decimal_missing_scale"),
pytest.param(_default_avro_format, {"type": "bytes", "logicalType": "uuid"}, {"type": ["null", "string"]}, None, id="test_uuid"),
pytest.param(_default_avro_format, {"type": "int", "logicalType": "date"}, {"type": ["null", "string"], "format": "date"}, None,
pytest.param(_default_avro_format, {"type": "bytes", "logicalType": "uuid"}, {"type": "string"}, None, id="test_uuid"),
pytest.param(_default_avro_format, {"type": "int", "logicalType": "date"}, {"type": "string", "format": "date"}, None,
id="test_date"),
pytest.param(_default_avro_format, {"type": "int", "logicalType": "time-millis"}, {"type": ["null", "integer"]}, None, id="test_time_millis"),
pytest.param(_default_avro_format, {"type": "long", "logicalType": "time-micros"}, {"type": ["null", "integer"]}, None,
pytest.param(_default_avro_format, {"type": "int", "logicalType": "time-millis"}, {"type": "integer"}, None, id="test_time_millis"),
pytest.param(_default_avro_format, {"type": "long", "logicalType": "time-micros"}, {"type": "integer"}, None,
id="test_time_micros"),
pytest.param(
_default_avro_format,
{"type": "long", "logicalType": "timestamp-millis"}, {"type": ["null", "string"], "format": "date-time"}, None, id="test_timestamp_millis"
{"type": "long", "logicalType": "timestamp-millis"}, {"type": "string", "format": "date-time"}, None, id="test_timestamp_millis"
),
pytest.param(_default_avro_format, {"type": "long", "logicalType": "timestamp-micros"}, {"type": ["null", "string"]}, None,
pytest.param(_default_avro_format, {"type": "long", "logicalType": "timestamp-micros"}, {"type": "string"}, None,
id="test_timestamp_micros"),
pytest.param(
_default_avro_format,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@
"streams": [
{
"name": "stream1",
"file_type": "avro",
"format": {"filetype": "avro"},
"globs": ["*"],
"validation_policy": "Emit Record",
}
Expand Down Expand Up @@ -266,7 +266,7 @@
"streams": [
{
"name": "stream1",
"file_type": "avro",
"format": {"filetype": "avro"},
"globs": ["*"],
"validation_policy": "Emit Record",
}
Expand Down Expand Up @@ -362,7 +362,7 @@
"streams": [
{
"name": "stream1",
"file_type": "avro",
"format": {"filetype": "avro"},
"globs": ["*"],
"validation_policy": "Emit Record",
}
Expand Down Expand Up @@ -463,13 +463,13 @@
"streams": [
{
"name": "songs_stream",
"file_type": "avro",
"format": {"filetype": "avro"},
"globs": ["*_songs.avro"],
"validation_policy": "Emit Record",
},
{
"name": "festivals_stream",
"file_type": "avro",
"format": {"filetype": "avro"},
"globs": ["*_festivals.avro"],
"validation_policy": "Emit Record",
},
Expand Down Expand Up @@ -629,7 +629,6 @@
"streams": [
{
"name": "stream1",
"file_type": "avro",
"globs": ["*"],
"validation_policy": "Emit Record",
"format": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"streams": [
{
"name": "stream1",
"file_type": "csv",
"format": {"filetype": "csv"},
"globs": ["*.csv"],
"validation_policy": "Emit Record",
}
Expand Down Expand Up @@ -55,13 +55,13 @@
"streams": [
{
"name": "stream1",
"file_type": "csv",
"format": {"filetype": "csv"},
"globs": ["*.csv", "*.gz"],
"validation_policy": "Emit Record",
},
{
"name": "stream2",
"file_type": "csv",
"format": {"filetype": "csv"},
"globs": ["*.csv", "*.gz"],
"validation_policy": "Emit Record",
}
Expand All @@ -79,7 +79,7 @@
"streams": [
{
"name": "stream1",
"file_type": "csv",
"format": {"filetype": "csv"},
"globs": ["*"],
"validation_policy": "Emit Record",
}
Expand Down Expand Up @@ -109,7 +109,7 @@
"streams": [
{
"name": "stream1",
"file_type": "csv",
"format": {"filetype": "csv"},
"globs": ["*.csv"],
"validation_policy": "Emit Record",
"input_schema": '{"col1": "string", "col2": "string"}',
Expand Down Expand Up @@ -158,7 +158,7 @@
"streams": [
{
"name": "stream1",
"file_type": "csv",
"format": {"filetype": "csv"},
"globs": ["*.csv"],
"validation_policy": "always_fail",
"input_schema": '{"col1": "number", "col2": "string"}',
Expand All @@ -179,13 +179,13 @@
"streams": [
{
"name": "stream1",
"file_type": "csv",
"format": {"filetype": "csv"},
"globs": ["*.csv"],
"validation_policy": "Emit Record",
},
{
"name": "stream2",
"file_type": "jsonl",
"format": {"filetype": "jsonl"},
"globs": ["*.csv"],
"validation_policy": "Emit Record",
}
Expand Down
Loading

0 comments on commit b6836ad

Please sign in to comment.