Add support for using multiple input folders #90

Status: Open. Wants to merge 10 commits into base: alpha.
8 changes: 7 additions & 1 deletion client/config.yaml
@@ -38,4 +38,10 @@ translation_config:
# - __DEFAULT_SCHEMA__

# Set this to True (default) to clean up the temporary data in the '.tmp_processed' folder after job finishes.
clean_up_tmp_files: False

+# [Optional field] The mapping from source file paths to optional target file paths.
+# source_target_location_mapping: {
+#   "source_path1": "target_path",
+#   "source_path2": "",
+# }
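For illustration, a filled-in mapping with hypothetical paths parses as below; under the download logic added in this PR, a source mapped to an empty string sends its results to the first non-empty target in the mapping:

import yaml

# Hypothetical snippet in the flow-mapping style of the comment above.
snippet = """
source_target_location_mapping: {
  "/abs/teradata_sql": "translated/teradata",
  "/abs/shared_macros": ""
}
"""
print(yaml.safe_load(snippet))
# {'source_target_location_mapping': {'/abs/teradata_sql': 'translated/teradata',
#                                     '/abs/shared_macros': ''}}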
53 changes: 41 additions & 12 deletions client/dwh_migration_client/batch_sql_translator.py
@@ -21,8 +21,8 @@
import time
import uuid
from datetime import datetime
-from os.path import dirname, join
-from typing import Optional
+from os.path import dirname, join, abspath
+from typing import Optional, OrderedDict

from google.cloud.bigquery import migration_v2alpha as bigquery_migration_v2

@@ -89,20 +89,49 @@ def start_translation(self) -> None:
            f"gs://{self.config.gcp_settings.gcs_bucket}", gcs_path, "output"
        )
        logging.info("Uploading inputs to gcs ...")
-        gcs_util.upload_directory(
-            local_input_dir,
-            self.config.gcp_settings.gcs_bucket,
-            join(gcs_path, "input"),
-        )
+
+        local_source_target_location_mapping = OrderedDict()
+        if self.config.translation_config.source_target_location_mapping:
+            # Use the input/output specified by the user arguments when not default.
+            if local_input_dir != "client/input":
+                local_source_target_location_mapping[
+                    abspath(local_input_dir)
+                ] = local_output_dir
+            local_source_target_location_mapping.update(
+                self.config.translation_config.source_target_location_mapping
+            )
+
+        if not local_source_target_location_mapping:
+            gcs_util.upload_directory(
+                local_input_dir,
+                self.config.gcp_settings.gcs_bucket,
+                join(gcs_path, "input"),
+            )
+        else:
+            # Upload using the map.
+            gcs_util.upload_full_directories(
+                local_source_target_location_mapping.keys(),
+                self.config.gcp_settings.gcs_bucket,
+                join(gcs_path, "input"),
+            )

        logging.info("Start translation job...")
        job_name = self.create_migration_workflow(gcs_input_path, gcs_output_path)
        self._wait_until_job_finished(job_name)
        logging.info("Downloading outputs...")
-        gcs_util.download_directory(
-            local_output_dir,
-            self.config.gcp_settings.gcs_bucket,
-            join(gcs_path, "output"),
-        )
+        if not local_source_target_location_mapping:
+            gcs_util.download_directory(
+                local_output_dir,
+                self.config.gcp_settings.gcs_bucket,
+                join(gcs_path, "output"),
+            )
+        else:
+            # Download using the map.
+            gcs_util.download_directories(
+                local_source_target_location_mapping,
+                self.config.gcp_settings.gcs_bucket,
+                join(gcs_path, "output"),
+            )

        if self.preprocessor is not None:
            logging.info(
…
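To make the ordering rules concrete, here is a small sketch of how the mapping is assembled (hypothetical paths; `client/input` is the default input directory the diff compares against, so a non-default input is prepended ahead of the config entries):

from collections import OrderedDict
from os.path import abspath

# Hypothetical CLI arguments and config-file mapping.
cli_input, cli_output = "my_sql", "my_out"
config_mapping = {"/abs/extra_sql": "extra_out", "/abs/shared_macros": ""}

mapping = OrderedDict()
if config_mapping:                    # only built when the config field is set
    if cli_input != "client/input":  # skip the default input directory
        mapping[abspath(cli_input)] = cli_output
    mapping.update(config_mapping)
# Order matters on download: sources mapped to "" fall back to the first
# non-empty target in the mapping (here: "my_out").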
6 changes: 5 additions & 1 deletion client/dwh_migration_client/config.py
@@ -16,7 +16,7 @@
import logging
from dataclasses import asdict, dataclass
from pprint import pformat
-from typing import List, Optional
+from typing import List, Dict, Optional

import yaml
from marshmallow import Schema, ValidationError, fields, post_load
@@ -49,6 +49,7 @@ class TranslationConfig:
    default_database: Optional[str]
    schema_search_path: Optional[List[str]]
    clean_up_tmp_files: bool
+    source_target_location_mapping: Optional[Dict[str, str]]


class TranslationConfigSchema(Schema):
@@ -69,6 +70,9 @@ def _deserialize_translation_type(obj: str) -> TranslationType:
    default_database = fields.String(load_default=None)
    schema_search_path = fields.List(fields.String(), load_default=None)
    clean_up_tmp_files = fields.Boolean(load_default=True)
+    source_target_location_mapping = fields.Dict(
+        keys=fields.String(), values=fields.String(), load_default=None
+    )

    @post_load
    def build(self, data, **kwargs):  # type: ignore[no-untyped-def] # pylint: disable=unused-argument
…
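A minimal, self-contained sketch of how the new marshmallow field behaves (a demo schema only, not the project's full `TranslationConfigSchema`):

from marshmallow import Schema, fields

class _DemoSchema(Schema):
    source_target_location_mapping = fields.Dict(
        keys=fields.String(), values=fields.String(), load_default=None
    )

print(_DemoSchema().load({}))
# {'source_target_location_mapping': None}
print(_DemoSchema().load(
    {"source_target_location_mapping": {"/a/sql": "out", "/b/sql": ""}}
))
# {'source_target_location_mapping': {'/a/sql': 'out', '/b/sql': ''}}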
106 changes: 106 additions & 0 deletions client/dwh_migration_client/gcs_util.py
@@ -16,6 +16,8 @@
import logging
import os
from os.path import abspath, basename, isdir, join
+from typing import List, OrderedDict
+from pathlib import Path

from google.cloud import storage
from google.cloud.exceptions import NotFound
@@ -57,6 +59,56 @@ def upload_directory(local_dir: str, bucket_name: str, gcs_path: str) -> None:
    )


def upload_full_directories(
    local_dir_list: List[str], bucket_name: str, gcs_path: str
) -> None:
    """Uploads all the files from local directories to a gcs bucket using the full paths.

    Args:
        local_dir_list: paths to the local directories.
        bucket_name: name of the gcs bucket. If the bucket doesn't exist, the method
            tries to create one.
        gcs_path: the path to the gcs directory that stores the files.
    """
    for local_dir in local_dir_list:
        upload_full_directory(local_dir, bucket_name, gcs_path)


def upload_full_directory(local_dir: str, bucket_name: str, gcs_path: str) -> None:
    """Uploads all the files from a local directory to a gcs bucket using the full/absolute file path.

    Args:
        local_dir: path to the local directory.
        bucket_name: name of the gcs bucket. If the bucket doesn't exist, the method
            tries to create one.
        gcs_path: the path to the gcs directory that stores the files.
    """
    assert isdir(local_dir), f"Can't find input directory {local_dir}."
    client = storage.Client()

    try:
        logging.info("Get bucket %s", bucket_name)
        bucket: Bucket = client.get_bucket(bucket_name)
    except NotFound:
        logging.info('The bucket "%s" does not exist, creating one...', bucket_name)
        bucket = client.create_bucket(bucket_name)

    dir_abs_path = abspath(local_dir)
    for root, _, files in os.walk(dir_abs_path):
        for name in files:
            file_path = join(root, name)
            logging.info('Uploading file "%s" to gcs...', file_path)
            # Mirror the absolute path (minus the leading "/") under gcs_path.
            gcs_file_path = join(gcs_path, file_path[1:])
            blob = bucket.blob(gcs_file_path)
            blob.upload_from_filename(file_path)
    logging.info(
        'Finished uploading input files to gcs "%s/%s".', bucket_name, gcs_path
    )

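The effect is that each file's absolute path is mirrored under the job's GCS prefix, so files from different source trees cannot collide. A quick sketch of the path math with hypothetical values:

from os.path import join

local_file = "/home/me/teradata_sql/reports/q1.sql"  # hypothetical absolute path
gcs_path = "20240101-job1234/input"                  # hypothetical job prefix
print(join(gcs_path, local_file[1:]))
# 20240101-job1234/input/home/me/teradata_sql/reports/q1.sql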

def download_directory(local_dir: str, bucket_name: str, gcs_path: str) -> None:
"""Download all the files from a gcs bucket to a local directory.

Expand All @@ -79,3 +131,57 @@ def download_directory(local_dir: str, bucket_name: str, gcs_path: str) -> None:
        blob.download_to_filename(file_path)

    logging.info('Finished downloading. Output files are in "%s".', local_dir)


def download_directories(
    local_dir_map: OrderedDict[str, str], bucket_name: str, gcs_path: str
) -> None:
    """Download files from specific directories in a gcs bucket to local directories.

    Args:
        local_dir_map: mapping from source input locations to the local directories
            that receive the downloaded files. Directories are created if they
            don't exist.
        bucket_name: name of the gcs bucket.
        gcs_path: the path to the gcs directory that stores the files.
    """
    client = storage.Client()
    blobs = client.list_blobs(bucket_name, prefix=gcs_path)
    logging.info('Start downloading outputs from gcs "%s/%s"', bucket_name, gcs_path)
    for blob in blobs:
        file_name = basename(blob.name)
        sub_dir = blob.name[len(gcs_path) + 1 : -len(file_name)]
        # Determine local_dir based on which source dir the blob belongs to:
        # None means "no matching source"; "" means "matched, no explicit target".
        local_dir = None
        for source_dir, local_target in local_dir_map.items():
            if Path("/" + sub_dir).is_relative_to(source_dir):
                local_dir = local_target if local_target and local_target.strip() else ""
                # Strip the abs-path prefix from sub_dir so it joins properly
                # with the output directory below.
                source_dir = source_dir.strip("/")
                sub_dir = sub_dir[len(source_dir) + 1 :]
                break
        if local_dir is None:
            logging.info(
                'Skipping downloading "%s" because no output directory was selected.',
                file_name,
            )
            continue
        if not local_dir:
            # Result files without an explicit target go to the first non-empty
            # output directory in the mapping.
            for target_dir in local_dir_map.values():
                if target_dir and target_dir.strip():
                    local_dir = target_dir
                    break
        file_dir = join(local_dir, sub_dir)
        os.makedirs(file_dir, exist_ok=True)
        file_path = join(file_dir, file_name)
        logging.info('Downloading output file to "%s"...', file_path)
        blob.download_to_filename(file_path)

    logging.info('Finished downloading. Output files are in "%s".', local_dir_map)
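A worked example of the matching above, with hypothetical values (note that `Path.is_relative_to` requires Python 3.9+):

from pathlib import Path

local_dir_map = {"/home/me/teradata_sql": "my_out", "/home/me/macros": ""}
gcs_path = "20240101-job1234/output"
blob_name = gcs_path + "/home/me/teradata_sql/reports/q1.sql"

file_name = "q1.sql"
sub_dir = blob_name[len(gcs_path) + 1 : -len(file_name)]  # "home/me/teradata_sql/reports/"
print(Path("/" + sub_dir).is_relative_to("/home/me/teradata_sql"))  # True
# -> q1.sql lands in my_out/reports/; anything under /home/me/macros has an
#    empty target and falls back to the first non-empty target, my_out.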
6 changes: 6 additions & 0 deletions client/google/cloud/bigquery/migration/__init__.py
@@ -55,8 +55,11 @@
from google.cloud.bigquery.migration_v2alpha.types.translation_config import RedshiftDialect
from google.cloud.bigquery.migration_v2alpha.types.translation_config import SnowflakeDialect
from google.cloud.bigquery.migration_v2alpha.types.translation_config import SourceEnv
+from google.cloud.bigquery.migration_v2alpha.types.translation_config import SourceLocation
+from google.cloud.bigquery.migration_v2alpha.types.translation_config import SourceTargetLocationMapping
from google.cloud.bigquery.migration_v2alpha.types.translation_config import SparkSQLDialect
from google.cloud.bigquery.migration_v2alpha.types.translation_config import SQLServerDialect
+from google.cloud.bigquery.migration_v2alpha.types.translation_config import TargetLocation
from google.cloud.bigquery.migration_v2alpha.types.translation_config import TeradataDialect
from google.cloud.bigquery.migration_v2alpha.types.translation_config import TranslationConfigDetails
from google.cloud.bigquery.migration_v2alpha.types.translation_config import VerticaDialect
@@ -108,8 +111,11 @@
    'RedshiftDialect',
    'SnowflakeDialect',
    'SourceEnv',
+    'SourceLocation',
+    'SourceTargetLocationMapping',
    'SparkSQLDialect',
    'SQLServerDialect',
+    'TargetLocation',
    'TeradataDialect',
    'TranslationConfigDetails',
    'VerticaDialect',
…
6 changes: 6 additions & 0 deletions client/google/cloud/bigquery/migration_v2alpha/__init__.py
@@ -55,8 +55,11 @@
from .types.translation_config import RedshiftDialect
from .types.translation_config import SnowflakeDialect
from .types.translation_config import SourceEnv
+from .types.translation_config import SourceLocation
+from .types.translation_config import SourceTargetLocationMapping
from .types.translation_config import SparkSQLDialect
from .types.translation_config import SQLServerDialect
+from .types.translation_config import TargetLocation
from .types.translation_config import TeradataDialect
from .types.translation_config import TranslationConfigDetails
from .types.translation_config import VerticaDialect
@@ -110,8 +113,11 @@
    'SQLServerDialect',
    'SnowflakeDialect',
    'SourceEnv',
+    'SourceLocation',
+    'SourceTargetLocationMapping',
    'SparkSQLDialect',
    'StartMigrationWorkflowRequest',
+    'TargetLocation',
    'TeradataDialect',
    'TeradataOptions',
    'TimeInterval',
…
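With these re-exports in place, the three new messages resolve from the package roots, for example (the sibling `google.cloud.bigquery.migration` and types packages receive the same additions):

from google.cloud.bigquery.migration_v2alpha import (
    SourceLocation,
    SourceTargetLocationMapping,
    TargetLocation,
)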
@@ -62,8 +62,11 @@
    RedshiftDialect,
    SnowflakeDialect,
    SourceEnv,
+    SourceLocation,
+    SourceTargetLocationMapping,
    SparkSQLDialect,
    SQLServerDialect,
+    TargetLocation,
    TeradataDialect,
    TranslationConfigDetails,
    VerticaDialect,
@@ -117,8 +120,11 @@
    'RedshiftDialect',
    'SnowflakeDialect',
    'SourceEnv',
+    'SourceLocation',
+    'SourceTargetLocationMapping',
    'SparkSQLDialect',
    'SQLServerDialect',
+    'TargetLocation',
    'TeradataDialect',
    'TranslationConfigDetails',
    'VerticaDialect',
…
@@ -138,7 +138,7 @@ class MigrationTask(proto.Message):
            Translation_Snowflake2BQ, Translation_Netezza2BQ,
            Translation_AzureSynapse2BQ, Translation_Vertica2BQ,
            Translation_SQLServer2BQ, Translation_Presto2BQ,
-            Translation_MySQL2BQ.
+            Translation_MySQL2BQ, Translation_Postgresql2BQ.
        details (google.protobuf.any_pb2.Any):
            DEPRECATED! Use one of the task_details below. The details
            of the task. The type URL must be one of the supported task
…
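For context, this string is the task type the client sends when it creates the migration workflow. A minimal sketch follows; constructing a bare task like this, and the `type_` field name (proto-plus renames fields that collide with builtins), are assumptions here, since the workflow-assembly code is outside this diff:

from google.cloud.bigquery import migration_v2alpha as bigquery_migration_v2

# Hypothetical: select the newly documented PostgreSQL dialect.
task = bigquery_migration_v2.MigrationTask(type_="Translation_Postgresql2BQ")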