Feat: Multiple perf improvements and viztracer performance profiling (
aaronsteers authored Jul 17, 2024
1 parent 326683c commit 1acd6df
Showing 12 changed files with 345 additions and 101 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
# Viztracer log files
viztracer_report.json

# Packaged docs
docs/*.zip

8 changes: 8 additions & 0 deletions .viztracerrc
@@ -0,0 +1,8 @@
; Default config settings for viztracer
; https://viztracer.readthedocs.io/en/latest/basic_usage.html#configuration-file

[default]
max_stack_depth = 20
unique_output_file = True
output_file = viztracer_report.json
tracer_entries = 5_000_000
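
The same settings can also be applied when tracing programmatically rather than through the `viztracer` CLI. The sketch below is a minimal illustration only; it assumes the `VizTracer` constructor accepts keyword arguments mirroring these config keys (as the documented CLI flags do), so verify against the viztracer docs before relying on it.

```
# Minimal sketch: programmatic equivalent of the .viztracerrc settings above.
from viztracer import VizTracer


def workload() -> int:
    # Trivial stand-in for the code under test.
    return sum(range(1_000_000))


with VizTracer(
    max_stack_depth=20,
    output_file="viztracer_report.json",
    tracer_entries=5_000_000,
):
    workload()
```
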
14 changes: 13 additions & 1 deletion airbyte/_future_cdk/sql_processor.py
@@ -161,6 +161,8 @@ def __init__(
)
self.type_converter = self.type_converter_class()
self._cached_table_definitions: dict[str, sqlalchemy.Table] = {}

self._known_schemas_list: list[str] = []
self._ensure_schema_exists()

# Public interface:
@@ -305,6 +307,10 @@ def _ensure_schema_exists(
) -> None:
"""Return a new (unique) temporary table name."""
schema_name = self.sql_config.schema_name

if self._known_schemas_list and self.sql_config.schema_name in self._known_schemas_list:
return # Already exists

if schema_name in self._get_schemas_list():
return

@@ -372,17 +378,23 @@ def _get_tables_list(
def _get_schemas_list(
self,
database_name: str | None = None,
*,
force_refresh: bool = False,
) -> list[str]:
"""Return a list of all tables in the database."""
if not force_refresh and self._known_schemas_list:
return self._known_schemas_list

inspector: Inspector = sqlalchemy.inspect(self.get_sql_engine())
database_name = database_name or self.database_name
found_schemas = inspector.get_schema_names()
return [
self._known_schemas_list = [
found_schema.split(".")[-1].strip('"')
for found_schema in found_schemas
if "." not in found_schema
or (found_schema.split(".")[0].lower().strip('"') == database_name.lower())
]
return self._known_schemas_list

def _ensure_final_table_exists(
self,
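
The change above avoids re-inspecting the database for schema names on every call by caching the result and exposing an explicit `force_refresh` escape hatch. A minimal, self-contained sketch of the same pattern follows; it is a hypothetical class, not the actual `SqlProcessorBase`.

```
from collections.abc import Callable, Iterable


class SchemaCache:
    """Toy illustration of the caching pattern used in `_get_schemas_list`."""

    def __init__(self, fetch_schemas: Callable[[], Iterable[str]]) -> None:
        self._fetch = fetch_schemas          # callable that queries the database
        self._known_schemas: list[str] = []  # empty list means "not cached yet"

    def get_schemas(self, *, force_refresh: bool = False) -> list[str]:
        if not force_refresh and self._known_schemas:
            return self._known_schemas       # cache hit: skip the round-trip
        self._known_schemas = list(self._fetch())
        return self._known_schemas

    def ensure_schema_exists(
        self,
        schema_name: str,
        create_schema: Callable[[str], None],
    ) -> None:
        if schema_name in self.get_schemas():
            return                            # already known to exist
        create_schema(schema_name)
        self.get_schemas(force_refresh=True)  # pick up the newly created schema
```
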
2 changes: 1 addition & 1 deletion airbyte/_processors/file/base.py
@@ -23,7 +23,7 @@
)


DEFAULT_BATCH_SIZE = 10000
DEFAULT_BATCH_SIZE = 100_000


class FileWriterBase(abc.ABC):
17 changes: 16 additions & 1 deletion airbyte/_processors/sql/snowflake.py
@@ -3,6 +3,7 @@

from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor
from textwrap import dedent, indent
from typing import TYPE_CHECKING

@@ -12,6 +13,7 @@
from snowflake import connector
from snowflake.sqlalchemy import URL, VARIANT

from airbyte import exceptions as exc
from airbyte._future_cdk import SqlProcessorBase
from airbyte._future_cdk.sql_processor import SqlConfig
from airbyte._processors.file.jsonl import JsonlWriter
@@ -26,6 +28,9 @@
from sqlalchemy.engine import Connection


MAX_UPLOAD_THREADS = 8


class SnowflakeConfig(SqlConfig):
"""Configuration for the Snowflake cache."""

@@ -120,10 +125,20 @@ def _write_files_to_new_table(
def path_str(path: Path) -> str:
return str(path.absolute()).replace("\\", "\\\\")

for file_path in files:
def upload_file(file_path: Path) -> None:
query = f"PUT 'file://{path_str(file_path)}' {internal_sf_stage_name};"
self._execute_sql(query)

# Upload files in parallel
with ThreadPoolExecutor(max_workers=MAX_UPLOAD_THREADS) as executor:
try:
executor.map(upload_file, files)
except Exception as e:
raise exc.PyAirbyteInternalError(
message="Failed to upload batch files to Snowflake.",
context={"files": [str(f) for f in files]},
) from e

columns_list = [
self._quote_identifier(c)
for c in list(self._get_sql_column_definitions(stream_name).keys())
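
One caveat for readers reusing the pattern above: `ThreadPoolExecutor.map()` returns a lazy iterator, so an exception raised inside `upload_file` propagates only when the map results are consumed, not at the `executor.map(...)` call wrapped in `try`. The sketch below is one way to surface worker errors explicitly; it is not the committed code, and `upload_file` here is a stand-in for the closure defined in the diff.

```
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

MAX_UPLOAD_THREADS = 8


def upload_all(files: list[Path], upload_file) -> None:
    """Upload files in parallel and re-raise the first worker exception."""
    with ThreadPoolExecutor(max_workers=MAX_UPLOAD_THREADS) as executor:
        futures = {executor.submit(upload_file, path): path for path in files}
        for future in as_completed(futures):
            # .result() re-raises any exception raised inside the worker thread;
            # the failing file is available via futures[future] for error context.
            future.result()
```
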
2 changes: 2 additions & 0 deletions airbyte/_util/name_normalizers.py
@@ -4,6 +4,7 @@
from __future__ import annotations

import abc
import functools
import re
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -48,6 +49,7 @@ class LowerCaseNormalizer(NameNormalizerBase):
"""A name normalizer that converts names to lower case."""

@staticmethod
@functools.cache
def normalize(name: str) -> str:
"""Return the normalized name.
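
Adding `functools.cache` memoizes normalization per distinct name, which can pay off when the same stream and column names are normalized repeatedly during a sync. A toy stand-in (not the real normalizer logic) showing the effect:

```
import functools


@functools.cache
def normalize(name: str) -> str:
    # Toy stand-in: the real LowerCaseNormalizer does more than lower-casing.
    return name.lower()


normalize("Purchases")         # first call: the function body runs
normalize("Purchases")         # repeat call: answered from the cache
print(normalize.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)
```
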
4 changes: 3 additions & 1 deletion airbyte/caches/bigquery.py
@@ -17,6 +17,8 @@

from __future__ import annotations

from typing import NoReturn

from pydantic import PrivateAttr

from airbyte._processors.sql.bigquery import BigQueryConfig, BigQuerySqlProcessor
@@ -36,7 +38,7 @@ def get_arrow_dataset(
stream_name: str,
*,
max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
) -> None:
) -> NoReturn:
"""Raises NotImplementedError; BigQuery doesn't support `pd.read_sql_table`.
https://github.com/airbytehq/PyAirbyte/issues/165
"""
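
The return annotation changes from `None` to `NoReturn` because the method always raises. A small generic illustration of what the annotation buys (not the `BigQueryCache` code):

```
from typing import NoReturn


def always_raises() -> NoReturn:
    # `NoReturn` tells type checkers this function never returns normally.
    raise NotImplementedError("Not supported.")


def caller() -> int:
    always_raises()
    return 0  # type checkers can flag this line as unreachable
```
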
19 changes: 11 additions & 8 deletions airbyte/progress.py
@@ -85,20 +85,20 @@ def _to_time_str(timestamp: float) -> str:
return datetime_obj.strftime("%H:%M:%S")


def _get_elapsed_time_str(seconds: int) -> str:
def _get_elapsed_time_str(seconds: float) -> str:
"""Return duration as a string.
Seconds are included until 10 minutes is exceeded.
Minutes are always included after 1 minute elapsed.
Hours are always included after 1 hour elapsed.
"""
if seconds <= 60: # noqa: PLR2004 # Magic numbers OK here.
return f"{seconds} seconds"
return f"{seconds:.0f} seconds"

if seconds < 60 * 10:
minutes = seconds // 60
seconds %= 60
return f"{minutes}min {seconds}s"
return f"{minutes}min {seconds:.0f}s"

if seconds < 60 * 60:
minutes = seconds // 60
@@ -280,12 +280,12 @@ def elapsed_seconds_since_last_update(self) -> float | None:
return time.time() - self.last_update_time

@property
def elapsed_read_seconds(self) -> int:
def elapsed_read_seconds(self) -> float:
"""Return the number of seconds elapsed since the read operation started."""
if self.read_end_time is None:
return int(time.time() - self.read_start_time)
return time.time() - self.read_start_time

return int(self.read_end_time - self.read_start_time)
return self.read_end_time - self.read_start_time

@property
def elapsed_read_time_string(self) -> str:
@@ -366,7 +366,7 @@ def update_display(self, *, force_refresh: bool = False) -> None:
if (
not force_refresh
and self.last_update_time # if not set, then we definitely need to update
and cast(float, self.elapsed_seconds_since_last_update) < 0.5 # noqa: PLR2004
and cast(float, self.elapsed_seconds_since_last_update) < 0.8 # noqa: PLR2004
):
return

@@ -396,7 +396,10 @@ def _get_status_message(self) -> str:
start_time_str = _to_time_str(self.read_start_time)
records_per_second: float = 0.0
if self.elapsed_read_seconds > 0:
records_per_second = round(self.total_records_read / self.elapsed_read_seconds, 1)
records_per_second = round(
float(self.total_records_read) / self.elapsed_read_seconds,
ndigits=1,
)
status_message = (
f"## Read Progress\n\n"
f"Started reading at {start_time_str}.\n\n"
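
Beyond the int-to-float changes, the refresh threshold between display updates moves from 0.5 s to 0.8 s. A minimal sketch of that throttling pattern in isolation (a hypothetical helper, not the actual progress tracker):

```
import time


class ThrottledDisplay:
    """Skip redraws unless forced or enough time has passed since the last one."""

    def __init__(self, min_interval: float = 0.8) -> None:
        self.min_interval = min_interval
        self.last_update_time: float | None = None

    def update_display(self, *, force_refresh: bool = False) -> bool:
        now = time.time()
        if (
            not force_refresh
            and self.last_update_time is not None
            and now - self.last_update_time < self.min_interval
        ):
            return False  # too soon: skip this redraw
        self.last_update_time = now
        return True       # caller renders the status message now
```
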
2 changes: 1 addition & 1 deletion examples/run_bigquery_faker.py
@@ -49,7 +49,7 @@ def main() -> None:

cache = BigQueryCache(
project_name=bigquery_destination_secret["project_id"],
dataset_name=bigquery_destination_secret["dataset_id"],
dataset_name=bigquery_destination_secret.get("dataset_id", "pyairbyte_integtest"),
credentials_path=temp.name,
)

136 changes: 136 additions & 0 deletions examples/run_perf_test_reads.py
@@ -0,0 +1,136 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
"""
Simple script to get performance profile of read throughput.
This script accepts a single argument `-e=SCALE` as a power of 10.
-e=2 is equivalent to 500 records.
-e=3 is equivalent to 5_000 records.
-e=4 is equivalent to 50_000 records.
-e=5 is equivalent to 500_000 records.
-e=6 is equivalent to 5_000_000 records.
Use smaller values of `e` (2-3) to understand read and overhead costs.
Use larger values of `e` (4-5) to understand write throughput at scale.
For performance profiling, use `viztracer` to generate a flamegraph:
```
poetry run viztracer --open -- ./examples/run_perf_test_reads.py -e=3
poetry run viztracer --open -- ./examples/run_perf_test_reads.py -e=5
```
To run without profiling, prefix script name with `poetry run python`:
```
# Run with 5_000 records
poetry run python ./examples/run_perf_test_reads.py -e=3
# Run with 500_000 records
poetry run python ./examples/run_perf_test_reads.py -e=5
# Load 5_000 records to Snowflake
poetry run python ./examples/run_perf_test_reads.py -e=3 --cache=snowflake
# Load 5_000 records to BigQuery
poetry run python ./examples/run_perf_test_reads.py -e=3 --cache=bigquery
```
"""

from __future__ import annotations

import argparse
import tempfile

import airbyte as ab
from airbyte.caches import BigQueryCache, CacheBase, SnowflakeCache
from airbyte.secrets.google_gsm import GoogleGSMSecretManager


AIRBYTE_INTERNAL_GCP_PROJECT = "dataline-integration-testing"


def get_gsm_secret_json(secret_name: str) -> dict:
secret_mgr = GoogleGSMSecretManager(
project=AIRBYTE_INTERNAL_GCP_PROJECT,
credentials_json=ab.get_secret("GCP_GSM_CREDENTIALS"),
)
secret = secret_mgr.get_secret(
secret_name=secret_name,
)
assert secret is not None, "Secret not found."
return secret.parse_json()


def main(
e: int = 4,
cache_type: str = "duckdb",
) -> None:
e = e or 4
cache_type = cache_type or "duckdb"

cache: CacheBase
if cache_type == "duckdb":
cache = ab.new_local_cache()

elif cache_type == "snowflake":
secret_config = get_gsm_secret_json(
secret_name="AIRBYTE_LIB_SNOWFLAKE_CREDS",
)
cache = SnowflakeCache(
account=secret_config["account"],
username=secret_config["username"],
password=secret_config["password"],
database=secret_config["database"],
warehouse=secret_config["warehouse"],
role=secret_config["role"],
)

elif cache_type == "bigquery":
temp = tempfile.NamedTemporaryFile(mode="w+", delete=False, encoding="utf-8")
secret_config = get_gsm_secret_json(
secret_name="SECRET_DESTINATION-BIGQUERY_CREDENTIALS__CREDS",
)
try:
# Write credentials to the temp file
temp.write(secret_config["credentials_json"])
temp.flush()
finally:
temp.close()

cache = BigQueryCache(
project_name=secret_config["project_id"],
dataset_name=secret_config.get("dataset_id", "pyairbyte_integtest"),
credentials_path=temp.name,
)

source = ab.get_source(
"source-faker",
config={"count": 5 * (10**e)},
install_if_missing=False,
streams=["purchases"],
)
source.check()

source.read(cache)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run performance test reads.")
parser.add_argument(
"-e",
type=int,
help=(
"The scale, as a power of 10. "
"Recommended values: 2-3 (500 or 5_000) for read and overhead costs, "
" 4-6 (50K or 5MM) for write throughput."
),
)
parser.add_argument(
"--cache",
type=str,
help="The cache type to use.",
choices=["duckdb", "snowflake", "bigquery"],
default="duckdb",
)
args = parser.parse_args()

main(e=args.e, cache_type=args.cache)