diff --git a/client/config.yaml b/client/config.yaml
index ccba4bc16..f55a76755 100644
--- a/client/config.yaml
+++ b/client/config.yaml
@@ -38,4 +38,10 @@ translation_config:
   #   - __DEFAULT_SCHEMA__

   # Set this to True (default) to clean up the temporary data in the '.tmp_processed' folder after job finishes.
-  clean_up_tmp_files: False
\ No newline at end of file
+  clean_up_tmp_files: False
+
+  # [Optional field] The mapping from source file paths to optional target file paths.
+  # source_target_location_mapping: {
+  #   "source_path1":"target_path",
+  #   "source_path2":"",
+  # }
\ No newline at end of file
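For illustration, a minimal sketch (not part of the change) of what an uncommented mapping block parses to on the client side. The paths and the snippet variable are hypothetical; the empty-string semantics follow the download_directories change in gcs_util.py below.

    import yaml

    # Hypothetical, uncommented version of the new config block.
    snippet = """
    translation_config:
      source_target_location_mapping: {
        "/abs/source/sql": "/abs/target/output",
        "/abs/source/legacy_sql": "",
      }
    """
    mapping = yaml.safe_load(snippet)["translation_config"][
        "source_target_location_mapping"
    ]
    assert mapping["/abs/source/sql"] == "/abs/target/output"
    # An empty target string means that source's translated outputs are
    # skipped at download time (see download_directories below).
    assert mapping["/abs/source/legacy_sql"] == ""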
diff --git a/client/dwh_migration_client/batch_sql_translator.py b/client/dwh_migration_client/batch_sql_translator.py
index 99c0bd51c..eeb95f98c 100644
--- a/client/dwh_migration_client/batch_sql_translator.py
+++ b/client/dwh_migration_client/batch_sql_translator.py
@@ -21,8 +21,8 @@
 import time
 import uuid
 from datetime import datetime
-from os.path import dirname, join
-from typing import Optional
+from os.path import dirname, join, abspath
+from typing import Optional, OrderedDict

 from google.cloud.bigquery import migration_v2alpha as bigquery_migration_v2

@@ -89,20 +89,49 @@ def start_translation(self) -> None:
             f"gs://{self.config.gcp_settings.gcs_bucket}", gcs_path, "output"
         )
         logging.info("Uploading inputs to gcs ...")
-        gcs_util.upload_directory(
-            local_input_dir,
-            self.config.gcp_settings.gcs_bucket,
-            join(gcs_path, "input"),
-        )
+
+        local_source_target_location_mapping = OrderedDict()
+        if self.config.translation_config.source_target_location_mapping:
+            # Use the user-specified input/output when they are not the defaults.
+            if local_input_dir != "client/input":
+                local_source_target_location_mapping[
+                    abspath(local_input_dir)
+                ] = local_output_dir
+            local_source_target_location_mapping.update(
+                self.config.translation_config.source_target_location_mapping
+            )
+
+        if not local_source_target_location_mapping:
+            gcs_util.upload_directory(
+                local_input_dir,
+                self.config.gcp_settings.gcs_bucket,
+                join(gcs_path, "input"),
+            )
+        else:
+            # Upload using the map.
+            gcs_util.upload_full_directories(
+                local_source_target_location_mapping.keys(),
+                self.config.gcp_settings.gcs_bucket,
+                join(gcs_path, "input"),
+            )
+
         logging.info("Start translation job...")
         job_name = self.create_migration_workflow(gcs_input_path, gcs_output_path)
         self._wait_until_job_finished(job_name)
         logging.info("Downloading outputs...")
-        gcs_util.download_directory(
-            local_output_dir,
-            self.config.gcp_settings.gcs_bucket,
-            join(gcs_path, "output"),
-        )
+        if not local_source_target_location_mapping:
+            gcs_util.download_directory(
+                local_output_dir,
+                self.config.gcp_settings.gcs_bucket,
+                join(gcs_path, "output"),
+            )
+        else:
+            # Download using the map.
+            gcs_util.download_directories(
+                local_source_target_location_mapping,
+                self.config.gcp_settings.gcs_bucket,
+                join(gcs_path, "output"),
+            )

         if self.preprocessor is not None:
             logging.info(
diff --git a/client/dwh_migration_client/config.py b/client/dwh_migration_client/config.py
index 66c9e6642..62311ec05 100644
--- a/client/dwh_migration_client/config.py
+++ b/client/dwh_migration_client/config.py
@@ -16,7 +16,7 @@
 import logging
 from dataclasses import asdict, dataclass
 from pprint import pformat
-from typing import List, Optional
+from typing import List, Dict, Optional

 import yaml
 from marshmallow import Schema, ValidationError, fields, post_load
@@ -49,6 +49,7 @@ class TranslationConfig:
     default_database: Optional[str]
     schema_search_path: Optional[List[str]]
     clean_up_tmp_files: bool
+    source_target_location_mapping: Optional[Dict[str, str]]


 class TranslationConfigSchema(Schema):
@@ -69,6 +70,9 @@ def _deserialize_translation_type(obj: str) -> TranslationType:
     default_database = fields.String(load_default=None)
     schema_search_path = fields.List(fields.String(), load_default=None)
     clean_up_tmp_files = fields.Boolean(load_default=True)
+    source_target_location_mapping = fields.Dict(
+        keys=fields.String(), values=fields.String(), load_default=None
+    )

     @post_load
     def build(self, data, **kwargs):  # type: ignore[no-untyped-def] # pylint: disable=unused-argument
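For illustration, a standalone sketch (not part of the change) of how start_translation builds the upload map above; config_mapping and the directory names are hypothetical stand-ins for the parsed config and the user arguments.

    from collections import OrderedDict
    from os.path import abspath

    # Hypothetical stand-ins for the parsed config mapping and user arguments.
    config_mapping = {"/abs/source/sql": "/abs/target/output"}
    local_input_dir = "my_input"  # differs from the "client/input" default
    local_output_dir = "my_output"

    local_map = OrderedDict()
    if config_mapping:
        # The user-specified input/output pair is inserted before the config mapping.
        if local_input_dir != "client/input":
            local_map[abspath(local_input_dir)] = local_output_dir
        local_map.update(config_mapping)

    # Insertion order matters: the first non-empty target in the map is also the
    # fallback output directory used by download_directories in gcs_util.py below.
    assert next(iter(local_map)) == abspath("my_input")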
+ """ + assert isdir(local_dir), f"Can't find input directory {local_dir}." + client = storage.Client() + + try: + logging.info("Get bucket %s", bucket_name) + bucket: Bucket = client.get_bucket(bucket_name) + except NotFound: + logging.info('The bucket "%s" does not exist, creating one...', bucket_name) + bucket = client.create_bucket(bucket_name) + + dir_abs_path = abspath(local_dir) + for root, _, files in os.walk(dir_abs_path): + for name in files: + sub_dir = root[len(dir_abs_path) :] + if sub_dir.startswith("/"): + sub_dir = sub_dir[1:] + file_path = join(root, name) + logging.info('Uploading file "%s" to gcs...', file_path) + gcs_file_path = join(join(gcs_path, file_path[1:])) + blob = bucket.blob(gcs_file_path) + blob.upload_from_filename(file_path) + logging.info( + 'Finished uploading input files to gcs "%s/%s".', bucket_name, gcs_path + ) + + def download_directory(local_dir: str, bucket_name: str, gcs_path: str) -> None: """Download all the files from a gcs bucket to a local directory. @@ -79,3 +131,57 @@ def download_directory(local_dir: str, bucket_name: str, gcs_path: str) -> None: blob.download_to_filename(file_path) logging.info('Finished downloading. Output files are in "%s".', local_dir) + + +def download_directories( + local_dir_map: OrderedDict[str, str], bucket_name: str, gcs_path: str +) -> None: + """Download all the files from specific directories in a gcs bucket to local directories. + + Args: + local_dir_map: paths from the input locations to the local directories to store the downloaded files. It will + create the directory if it doesn't exist. + bucket_name: name of the gcs bucket. + gcs_path: the path to the gcs directory that stores the files. + """ + client = storage.Client() + blobs = client.list_blobs(bucket_name, prefix=gcs_path) + logging.info('Start downloading outputs from gcs "%s/%s"', bucket_name, gcs_path) + for blob in blobs: + file_name = basename(blob.name) + sub_dir = blob.name[len(gcs_path) + 1 : -len(file_name)] + # Determine local_dir based on what source dir it belongs to + local_dir = "" + for source_dir in local_dir_map.keys(): + local_target = local_dir_map[source_dir] + if Path("/" + sub_dir).is_relative_to(source_dir): + if local_target is not None and local_target and local_target.strip(): + local_dir = local_target + # Clean up sub_dir to no longer be an abs path so it will join properly for the output + if source_dir.startswith("/"): + source_dir = source_dir[1:] + if source_dir.endswith("/"): + source_dir = source_dir[:-1] + sub_dir = sub_dir[len(source_dir) + 1 :] + else: + local_dir = None + break + if local_dir is None: + logging.info( + 'Skipping downloading "%s" because no output directory was selected.', + file_name, + ) + continue + if not local_dir: + # The results files which should be output to the first valid output directory + for target_dir in local_dir_map.values(): + if target_dir is not None and target_dir and target_dir.strip(): + local_dir = target_dir + break + file_dir = join(local_dir, sub_dir) + os.makedirs(file_dir, exist_ok=True) + file_path = join(file_dir, file_name) + logging.info('Downloading output file to "%s"...', file_path) + blob.download_to_filename(file_path) + + logging.info('Finished downloading. 
diff --git a/client/google/cloud/bigquery/migration/__init__.py b/client/google/cloud/bigquery/migration/__init__.py
index 9c6e1b52e..74bbaae3f 100644
--- a/client/google/cloud/bigquery/migration/__init__.py
+++ b/client/google/cloud/bigquery/migration/__init__.py
@@ -55,8 +55,11 @@
 from google.cloud.bigquery.migration_v2alpha.types.translation_config import RedshiftDialect
 from google.cloud.bigquery.migration_v2alpha.types.translation_config import SnowflakeDialect
 from google.cloud.bigquery.migration_v2alpha.types.translation_config import SourceEnv
+from google.cloud.bigquery.migration_v2alpha.types.translation_config import SourceLocation
+from google.cloud.bigquery.migration_v2alpha.types.translation_config import SourceTargetLocationMapping
 from google.cloud.bigquery.migration_v2alpha.types.translation_config import SparkSQLDialect
 from google.cloud.bigquery.migration_v2alpha.types.translation_config import SQLServerDialect
+from google.cloud.bigquery.migration_v2alpha.types.translation_config import TargetLocation
 from google.cloud.bigquery.migration_v2alpha.types.translation_config import TeradataDialect
 from google.cloud.bigquery.migration_v2alpha.types.translation_config import TranslationConfigDetails
 from google.cloud.bigquery.migration_v2alpha.types.translation_config import VerticaDialect
@@ -108,8 +111,11 @@
     'RedshiftDialect',
     'SnowflakeDialect',
     'SourceEnv',
+    'SourceLocation',
+    'SourceTargetLocationMapping',
     'SparkSQLDialect',
     'SQLServerDialect',
+    'TargetLocation',
     'TeradataDialect',
     'TranslationConfigDetails',
     'VerticaDialect',
diff --git a/client/google/cloud/bigquery/migration_v2alpha/__init__.py b/client/google/cloud/bigquery/migration_v2alpha/__init__.py
index 4ea78e73c..9518275c7 100644
--- a/client/google/cloud/bigquery/migration_v2alpha/__init__.py
+++ b/client/google/cloud/bigquery/migration_v2alpha/__init__.py
@@ -55,8 +55,11 @@
 from .types.translation_config import RedshiftDialect
 from .types.translation_config import SnowflakeDialect
 from .types.translation_config import SourceEnv
+from .types.translation_config import SourceLocation
+from .types.translation_config import SourceTargetLocationMapping
 from .types.translation_config import SparkSQLDialect
 from .types.translation_config import SQLServerDialect
+from .types.translation_config import TargetLocation
 from .types.translation_config import TeradataDialect
 from .types.translation_config import TranslationConfigDetails
 from .types.translation_config import VerticaDialect
@@ -110,8 +113,11 @@
'SQLServerDialect',
'SnowflakeDialect',
'SourceEnv',
+'SourceLocation',
+'SourceTargetLocationMapping',
'SparkSQLDialect',
'StartMigrationWorkflowRequest',
+'TargetLocation',
'TeradataDialect',
'TeradataOptions',
'TimeInterval',
diff --git a/client/google/cloud/bigquery/migration_v2alpha/types/__init__.py b/client/google/cloud/bigquery/migration_v2alpha/types/__init__.py
index 8831d134c..87f9a92db 100644
--- a/client/google/cloud/bigquery/migration_v2alpha/types/__init__.py
+++ b/client/google/cloud/bigquery/migration_v2alpha/types/__init__.py
@@ -62,8 +62,11 @@
     RedshiftDialect,
     SnowflakeDialect,
     SourceEnv,
+    SourceLocation,
+    SourceTargetLocationMapping,
     SparkSQLDialect,
     SQLServerDialect,
+    TargetLocation,
     TeradataDialect,
     TranslationConfigDetails,
     VerticaDialect,
@@ -117,8 +120,11 @@
     'RedshiftDialect',
     'SnowflakeDialect',
     'SourceEnv',
+    'SourceLocation',
+    'SourceTargetLocationMapping',
     'SparkSQLDialect',
     'SQLServerDialect',
+    'TargetLocation',
     'TeradataDialect',
     'TranslationConfigDetails',
     'VerticaDialect',
diff --git a/client/google/cloud/bigquery/migration_v2alpha/types/migration_entities.py b/client/google/cloud/bigquery/migration_v2alpha/types/migration_entities.py
index 6384c61b2..636e76a83 100644
--- a/client/google/cloud/bigquery/migration_v2alpha/types/migration_entities.py
+++ b/client/google/cloud/bigquery/migration_v2alpha/types/migration_entities.py
@@ -138,7 +138,7 @@ class MigrationTask(proto.Message):
             Translation_Snowflake2BQ, Translation_Netezza2BQ,
             Translation_AzureSynapse2BQ, Translation_Vertica2BQ,
             Translation_SQLServer2BQ, Translation_Presto2BQ,
-            Translation_MySQL2BQ.
+            Translation_MySQL2BQ, Translation_Postgresql2BQ.
         details (google.protobuf.any_pb2.Any):
             DEPRECATED! Use one of the task_details below. The details
             of the task. The type URL must be one of the supported task
diff --git a/client/google/cloud/bigquery/migration_v2alpha/types/translation_config.py b/client/google/cloud/bigquery/migration_v2alpha/types/translation_config.py
index 06908bfdb..d960aa144 100644
--- a/client/google/cloud/bigquery/migration_v2alpha/types/translation_config.py
+++ b/client/google/cloud/bigquery/migration_v2alpha/types/translation_config.py
@@ -40,6 +40,9 @@
         'NameMappingKey',
         'NameMappingValue',
         'SourceEnv',
+        'SourceTargetLocationMapping',
+        'SourceLocation',
+        'TargetLocation',
     },
 )

@@ -75,6 +78,9 @@ class TranslationConfigDetails(proto.Message):
         source_env (google.cloud.bigquery.migration_v2alpha.types.SourceEnv):
             The default source environment values for the
            translation.
+        source_target_location_mapping (Sequence[google.cloud.bigquery.migration_v2alpha.types.SourceTargetLocationMapping]):
+            The mapping from source location paths to
+            target location paths.
     """

     gcs_source_path = proto.Field(
@@ -108,6 +114,11 @@ class TranslationConfigDetails(proto.Message):
         number=6,
         message='SourceEnv',
     )
+    source_target_location_mapping = proto.RepeatedField(
+        proto.MESSAGE,
+        number=7,
+        message='SourceTargetLocationMapping',
+    )


 class Dialect(proto.Message):
@@ -508,4 +519,67 @@ class SourceEnv(proto.Message):
     )


+class SourceTargetLocationMapping(proto.Message):
+    r"""Represents one mapping from a source location path to an
+    optional target location path.
+
+    Attributes:
+        source_location (google.cloud.bigquery.migration_v2alpha.types.SourceLocation):
+            The path to the location of the source data.
+        target_location (google.cloud.bigquery.migration_v2alpha.types.TargetLocation):
+            The path to the location of the target data.
+    """
+
+    source_location = proto.Field(
+        proto.MESSAGE,
+        number=1,
+        message='SourceLocation',
+    )
+    target_location = proto.Field(
+        proto.MESSAGE,
+        number=2,
+        message='TargetLocation',
+    )
+
+
+class SourceLocation(proto.Message):
+    r"""Represents one path to the location that holds source data.
+
+    .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+    Attributes:
+        gcs_path (str):
+            The Cloud Storage path for a directory of
+            files.
+
+            This field is a member of `oneof`_ ``location``.
+    """
+
+    gcs_path = proto.Field(
+        proto.STRING,
+        number=1,
+        oneof='location',
+    )
+
+
+class TargetLocation(proto.Message):
+    r"""Represents one path to the location that holds target data.
+
+    .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+    Attributes:
+        gcs_path (str):
+            The Cloud Storage path for a directory of
+            files.
+
+            This field is a member of `oneof`_ ``location``.
+    """
+
+    gcs_path = proto.Field(
+        proto.STRING,
+        number=1,
+        oneof='location',
+    )
+
+
 __all__ = tuple(sorted(__protobuf__.manifest))