From a6195ff09d2e01c1ba9fbc4172987408756bcf1d Mon Sep 17 00:00:00 2001
From: gabbybarbieri <19822590+gabbybarbieri@users.noreply.github.com>
Date: Thu, 25 Aug 2022 12:22:22 -0700
Subject: [PATCH 1/2] Add support for the source/target folder map.

---
 client/config.yaml                       |   8 +-
 .../batch_sql_translator.py              |  49 ++++++---
 client/dwh_migration_client/config.py    |   4 +-
 client/dwh_migration_client/gcs_util.py  | 102 ++++++++++++++++++
 4 files changed, 149 insertions(+), 14 deletions(-)

diff --git a/client/config.yaml b/client/config.yaml
index 1b2274198..5599fa573 100644
--- a/client/config.yaml
+++ b/client/config.yaml
@@ -38,4 +38,10 @@ translation_config:
   #   - __DEFAULT_SCHEMA__
 
   # Set this to True (default) to clean up the temporary data in the '.tmp_processed' folder after job finishes.
-  clean_up_tmp_files: False
\ No newline at end of file
+  clean_up_tmp_files: False
+
+  # [Optional field] The mapping from source file paths to target file paths.
+  # source_target_location_mapping: {
+  #   "source_path1":"target_path",
+  #   "source_path2":"",
+  # }
\ No newline at end of file
diff --git a/client/dwh_migration_client/batch_sql_translator.py b/client/dwh_migration_client/batch_sql_translator.py
index 99c0bd51c..74b614146 100644
--- a/client/dwh_migration_client/batch_sql_translator.py
+++ b/client/dwh_migration_client/batch_sql_translator.py
@@ -21,8 +21,8 @@
 import time
 import uuid
 from datetime import datetime
-from os.path import dirname, join
-from typing import Optional
+from os.path import dirname, join, abspath
+from typing import Optional, Dict, OrderedDict
 
 from google.cloud.bigquery import migration_v2alpha as bigquery_migration_v2
 
@@ -89,20 +89,45 @@ def start_translation(self) -> None:
             f"gs://{self.config.gcp_settings.gcs_bucket}", gcs_path, "output"
         )
         logging.info("Uploading inputs to gcs ...")
-        gcs_util.upload_directory(
-            local_input_dir,
-            self.config.gcp_settings.gcs_bucket,
-            join(gcs_path, "input"),
-        )
+
+        local_source_target_location_mapping = OrderedDict()
+        if self.config.translation_config.source_target_location_mapping:
+            # Use the input/output specified by the user arguments when not default.
+            if local_input_dir != "client/input":
+                local_source_target_location_mapping[abspath(local_input_dir)] = local_output_dir
+            local_source_target_location_mapping.update(self.config.translation_config.source_target_location_mapping)
+
+        if not local_source_target_location_mapping:
+            gcs_util.upload_directory(
+                local_input_dir,
+                self.config.gcp_settings.gcs_bucket,
+                join(gcs_path, "input"),
+            )
+        else:
+            # Upload using the map.
+            gcs_util.upload_full_directories(
+                local_source_target_location_mapping.keys(),
+                self.config.gcp_settings.gcs_bucket,
+                join(gcs_path, "input"),
+            )
+
         logging.info("Start translation job...")
         job_name = self.create_migration_workflow(gcs_input_path, gcs_output_path)
         self._wait_until_job_finished(job_name)
         logging.info("Downloading outputs...")
-        gcs_util.download_directory(
-            local_output_dir,
-            self.config.gcp_settings.gcs_bucket,
-            join(gcs_path, "output"),
-        )
+        if not local_source_target_location_mapping:
+            gcs_util.download_directory(
+                local_output_dir,
+                self.config.gcp_settings.gcs_bucket,
+                join(gcs_path, "output"),
+            )
+        else:
+            # Download using the map
+            gcs_util.download_directories(
+                local_source_target_location_mapping,
+                self.config.gcp_settings.gcs_bucket,
+                join(gcs_path, "output"),
+            )
 
         if self.preprocessor is not None:
             logging.info(
diff --git a/client/dwh_migration_client/config.py b/client/dwh_migration_client/config.py
index 66c9e6642..c495ad72d 100644
--- a/client/dwh_migration_client/config.py
+++ b/client/dwh_migration_client/config.py
@@ -16,7 +16,7 @@
 import logging
 from dataclasses import asdict, dataclass
 from pprint import pformat
-from typing import List, Optional
+from typing import List, Dict, Optional
 
 import yaml
 from marshmallow import Schema, ValidationError, fields, post_load
@@ -49,6 +49,7 @@ class TranslationConfig:
     default_database: Optional[str]
     schema_search_path: Optional[List[str]]
     clean_up_tmp_files: bool
+    source_target_location_mapping: Optional[Dict[str,str]]
 
 
 class TranslationConfigSchema(Schema):
@@ -69,6 +70,7 @@ def _deserialize_translation_type(obj: str) -> TranslationType:
     default_database = fields.String(load_default=None)
     schema_search_path = fields.List(fields.String(), load_default=None)
     clean_up_tmp_files = fields.Boolean(load_default=True)
+    source_target_location_mapping = fields.Dict(keys=fields.String(), values=fields.String, load_default=None)
 
     @post_load
     def build(self, data, **kwargs):  # type: ignore[no-untyped-def] # pylint: disable=unused-argument
diff --git a/client/dwh_migration_client/gcs_util.py b/client/dwh_migration_client/gcs_util.py
index 98199ba94..51e868d8c 100644
--- a/client/dwh_migration_client/gcs_util.py
+++ b/client/dwh_migration_client/gcs_util.py
@@ -16,6 +16,10 @@
 import logging
 import os
 from os.path import abspath, basename, isdir, join
+from re import sub
+from threading import local
+from typing import Dict, List, OrderedDict
+from pathlib import Path
 
 from google.cloud import storage
 from google.cloud.exceptions import NotFound
@@ -56,6 +60,52 @@ def upload_directory(local_dir: str, bucket_name: str, gcs_path: str) -> None:
         'Finished uploading input files to gcs "%s/%s".', bucket_name, gcs_path
     )
 
+def upload_full_directories(local_dir_list: List[str], bucket_name: str, gcs_path: str) -> None:
+    """Uploads all the files from local directories to a gcs bucket using the full paths.
+
+    Args:
+        local_dir_list: paths to the local directories
+        bucket_name: name of the gcs bucket. If the bucket doesn't exist, the method
+            tries to create one.
+        gcs_path: the path to the gcs directory that stores the files.
+    """
+    for local_dir in local_dir_list:
+        upload_full_directory(local_dir, bucket_name, gcs_path)
+
+def upload_full_directory(local_dir: str, bucket_name: str, gcs_path: str) -> None:
+    """Uploads all the files from a local directory to a gcs bucket using the full file path.
+
+    Args:
+        local_dir: path to the local directory.
+        bucket_name: name of the gcs bucket. If the bucket doesn't exist, the method
+            tries to create one.
+        gcs_path: the path to the gcs directory that stores the files.
+    """
+    assert isdir(local_dir), f"Can't find input directory {local_dir}."
+    client = storage.Client()
+
+    try:
+        logging.info("Get bucket %s", bucket_name)
+        bucket: Bucket = client.get_bucket(bucket_name)
+    except NotFound:
+        logging.info('The bucket "%s" does not exist, creating one...', bucket_name)
+        bucket = client.create_bucket(bucket_name)
+
+    dir_abs_path = abspath(local_dir)
+    for root, _, files in os.walk(dir_abs_path):
+        for name in files:
+            sub_dir = root[len(dir_abs_path) :]
+            if sub_dir.startswith("/"):
+                sub_dir = sub_dir[1:]
+            file_path = join(root, name)
+            logging.info('Uploading file "%s" to gcs...', file_path)
+            # gcs_file_path = join(gcs_path, sub_dir, name)
+            gcs_file_path = join(join(gcs_path, file_path[1:]))
+            blob = bucket.blob(gcs_file_path)
+            blob.upload_from_filename(file_path)
+    logging.info(
+        'Finished uploading input files to gcs "%s/%s".', bucket_name, gcs_path
+    )
 
 def download_directory(local_dir: str, bucket_name: str, gcs_path: str) -> None:
     """Download all the files from a gcs bucket to a local directory.
@@ -79,3 +129,55 @@ def download_directory(local_dir: str, bucket_name: str, gcs_path: str) -> None
         blob.download_to_filename(file_path)
 
     logging.info('Finished downloading. Output files are in "%s".', local_dir)
+
+def download_directories(local_dir_map: OrderedDict[str,str], bucket_name: str, gcs_path: str) -> None:
+    """Download all the files from specific directories in a gcs bucket to local directories.
+
+    Args:
+        local_dir_map: paths from the input locations to the local directories to store the downloaded files. It will
+            create the directory if it doesn't exist.
+        bucket_name: name of the gcs bucket.
+        gcs_path: the path to the gcs directory that stores the files.
+    """
+    client = storage.Client()
+    blobs = client.list_blobs(bucket_name, prefix=gcs_path)
+    logging.info('Start downloading outputs from gcs "%s/%s"', bucket_name, gcs_path)
+    for blob in blobs:
+        file_name = basename(blob.name)
+        sub_dir = blob.name[len(gcs_path) + 1 : -len(file_name)]
+        # determine local_dir based on what dir belongs to
+        local_dir = ""
+        for source_dir in local_dir_map.keys():
+            # source_sub_dir = source_dir[-len(sub_dir) + 1:]
+            # print("source sub directory: " + source_sub_dir)
+            local_target = local_dir_map[source_dir]
+            if Path("/"+ sub_dir).is_relative_to(source_dir):
+                if local_target is not None and local_target and local_target.strip():
+                    local_dir = local_target
+                    # Clean up sub_dir so it is no longer with an abs path
+                    if source_dir.startswith("/"):
+                        source_dir = source_dir[1:]
+                    if source_dir.endswith("/"):
+                        source_dir = source_dir[:-1]
+                    sub_dir = sub_dir[len(source_dir) + 1:]
+                else:
+                    local_dir = None
+                break
+        if local_dir is None:
+            logging.info('Skipping downloading "%s" because no output directory was selected.', file_name)
+            continue
+        if not local_dir:
+            # The results files which should be output to the first valid output directory
+            for target_dir in local_dir_map.values():
+                if target_dir is not None and target_dir and target_dir.strip():
+                    local_dir = target_dir
+                    break
+        file_dir = join(local_dir, sub_dir)
+        os.makedirs(file_dir, exist_ok=True)
+        file_path = join(file_dir, file_name)
+        logging.info('Downloading output file to "%s"...', file_path)
+        blob.download_to_filename(file_path)
+
+    logging.info('Finished downloading. Output files are in "%s".', local_dir_map)
+
+    # Only want to output files that aren't in sub dir to the first directory in the list
\ No newline at end of file
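
As a concrete illustration of the new config knob (a minimal sketch; the local paths below are hypothetical, not part of the patch): each key is a local source directory to upload, and its value is the local directory where that source's translated outputs should land. A source mapped to an empty string is still uploaded and translated, but its outputs are skipped at download time, and files that fall under no source directory (such as top-level report files) are written to the first non-empty target in the map.

    translation_config:
      source_target_location_mapping: {
        "/home/user/teradata_sql": "/home/user/bq_sql",
        "/home/user/legacy_sql": "",
      }
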
Output files are in "%s".', local_dir_map) + + # Only want to output files that aren't in sub dir to the first directory in the list \ No newline at end of file From 1e679f7958882a35fd583dabcf4c9359a8361a4f Mon Sep 17 00:00:00 2001 From: gabbybarbieri <19822590+gabbybarbieri@users.noreply.github.com> Date: Thu, 25 Aug 2022 12:22:22 -0700 Subject: [PATCH 2/2] Add support for the source/target folder map. --- client/config.yaml | 8 +- .../batch_sql_translator.py | 53 +++++++-- client/dwh_migration_client/config.py | 6 +- client/dwh_migration_client/gcs_util.py | 106 ++++++++++++++++++ 4 files changed, 159 insertions(+), 14 deletions(-) diff --git a/client/config.yaml b/client/config.yaml index 1b2274198..2629f3717 100644 --- a/client/config.yaml +++ b/client/config.yaml @@ -38,4 +38,10 @@ translation_config: # - __DEFAULT_SCHEMA__ # Set this to True (default) to clean up the temporary data in the '.tmp_processed' folder after job finishes. - clean_up_tmp_files: False \ No newline at end of file + clean_up_tmp_files: False + + # [Optional field] The mapping from source file paths to optional target file paths. + # source_target_location_mapping: { + # "source_path1":"target_path", + # "source_path2":"", + # } \ No newline at end of file diff --git a/client/dwh_migration_client/batch_sql_translator.py b/client/dwh_migration_client/batch_sql_translator.py index 99c0bd51c..eeb95f98c 100644 --- a/client/dwh_migration_client/batch_sql_translator.py +++ b/client/dwh_migration_client/batch_sql_translator.py @@ -21,8 +21,8 @@ import time import uuid from datetime import datetime -from os.path import dirname, join -from typing import Optional +from os.path import dirname, join, abspath +from typing import Optional, OrderedDict from google.cloud.bigquery import migration_v2alpha as bigquery_migration_v2 @@ -89,20 +89,49 @@ def start_translation(self) -> None: f"gs://{self.config.gcp_settings.gcs_bucket}", gcs_path, "output" ) logging.info("Uploading inputs to gcs ...") - gcs_util.upload_directory( - local_input_dir, - self.config.gcp_settings.gcs_bucket, - join(gcs_path, "input"), - ) + + local_source_target_location_mapping = OrderedDict() + if self.config.translation_config.source_target_location_mapping: + # Use the input/output specified by the user arguments when not default. + if local_input_dir != "client/input": + local_source_target_location_mapping[ + abspath(local_input_dir) + ] = local_output_dir + local_source_target_location_mapping.update( + self.config.translation_config.source_target_location_mapping + ) + + if not local_source_target_location_mapping: + gcs_util.upload_directory( + local_input_dir, + self.config.gcp_settings.gcs_bucket, + join(gcs_path, "input"), + ) + else: + # Upload using the map. 
+            gcs_util.upload_full_directories(
+                local_source_target_location_mapping.keys(),
+                self.config.gcp_settings.gcs_bucket,
+                join(gcs_path, "input"),
+            )
+
         logging.info("Start translation job...")
         job_name = self.create_migration_workflow(gcs_input_path, gcs_output_path)
         self._wait_until_job_finished(job_name)
         logging.info("Downloading outputs...")
-        gcs_util.download_directory(
-            local_output_dir,
-            self.config.gcp_settings.gcs_bucket,
-            join(gcs_path, "output"),
-        )
+        if not local_source_target_location_mapping:
+            gcs_util.download_directory(
+                local_output_dir,
+                self.config.gcp_settings.gcs_bucket,
+                join(gcs_path, "output"),
+            )
+        else:
+            # Download using the map
+            gcs_util.download_directories(
+                local_source_target_location_mapping,
+                self.config.gcp_settings.gcs_bucket,
+                join(gcs_path, "output"),
+            )
 
         if self.preprocessor is not None:
             logging.info(
diff --git a/client/dwh_migration_client/config.py b/client/dwh_migration_client/config.py
index 66c9e6642..62311ec05 100644
--- a/client/dwh_migration_client/config.py
+++ b/client/dwh_migration_client/config.py
@@ -16,7 +16,7 @@
 import logging
 from dataclasses import asdict, dataclass
 from pprint import pformat
-from typing import List, Optional
+from typing import List, Dict, Optional
 
 import yaml
 from marshmallow import Schema, ValidationError, fields, post_load
@@ -49,6 +49,7 @@ class TranslationConfig:
     default_database: Optional[str]
     schema_search_path: Optional[List[str]]
     clean_up_tmp_files: bool
+    source_target_location_mapping: Optional[Dict[str, str]]
 
 
 class TranslationConfigSchema(Schema):
@@ -69,6 +70,9 @@ def _deserialize_translation_type(obj: str) -> TranslationType:
     default_database = fields.String(load_default=None)
     schema_search_path = fields.List(fields.String(), load_default=None)
     clean_up_tmp_files = fields.Boolean(load_default=True)
+    source_target_location_mapping = fields.Dict(
+        keys=fields.String(), values=fields.String, load_default=None
+    )
 
     @post_load
     def build(self, data, **kwargs):  # type: ignore[no-untyped-def] # pylint: disable=unused-argument
diff --git a/client/dwh_migration_client/gcs_util.py b/client/dwh_migration_client/gcs_util.py
index 98199ba94..b113501ac 100644
--- a/client/dwh_migration_client/gcs_util.py
+++ b/client/dwh_migration_client/gcs_util.py
@@ -16,6 +16,8 @@
 import logging
 import os
 from os.path import abspath, basename, isdir, join
+from typing import List, OrderedDict
+from pathlib import Path
 
 from google.cloud import storage
 from google.cloud.exceptions import NotFound
@@ -57,6 +59,56 @@ def upload_directory(local_dir: str, bucket_name: str, gcs_path: str) -> None:
     )
 
 
+def upload_full_directories(
+    local_dir_list: List[str], bucket_name: str, gcs_path: str
+) -> None:
+    """Uploads all the files from local directories to a gcs bucket using the full paths.
+
+    Args:
+        local_dir_list: paths to the local directories
+        bucket_name: name of the gcs bucket. If the bucket doesn't exist, the method
+            tries to create one.
+        gcs_path: the path to the gcs directory that stores the files.
+    """
+    for local_dir in local_dir_list:
+        upload_full_directory(local_dir, bucket_name, gcs_path)
+
+
+def upload_full_directory(local_dir: str, bucket_name: str, gcs_path: str) -> None:
+    """Uploads all the files from a local directory to a gcs bucket using the full/absolute file path.
+
+    Args:
+        local_dir: path to the local directory.
+        bucket_name: name of the gcs bucket. If the bucket doesn't exist, the method
+            tries to create one.
+        gcs_path: the path to the gcs directory that stores the files.
+ """ + assert isdir(local_dir), f"Can't find input directory {local_dir}." + client = storage.Client() + + try: + logging.info("Get bucket %s", bucket_name) + bucket: Bucket = client.get_bucket(bucket_name) + except NotFound: + logging.info('The bucket "%s" does not exist, creating one...', bucket_name) + bucket = client.create_bucket(bucket_name) + + dir_abs_path = abspath(local_dir) + for root, _, files in os.walk(dir_abs_path): + for name in files: + sub_dir = root[len(dir_abs_path) :] + if sub_dir.startswith("/"): + sub_dir = sub_dir[1:] + file_path = join(root, name) + logging.info('Uploading file "%s" to gcs...', file_path) + gcs_file_path = join(join(gcs_path, file_path[1:])) + blob = bucket.blob(gcs_file_path) + blob.upload_from_filename(file_path) + logging.info( + 'Finished uploading input files to gcs "%s/%s".', bucket_name, gcs_path + ) + + def download_directory(local_dir: str, bucket_name: str, gcs_path: str) -> None: """Download all the files from a gcs bucket to a local directory. @@ -79,3 +131,57 @@ def download_directory(local_dir: str, bucket_name: str, gcs_path: str) -> None: blob.download_to_filename(file_path) logging.info('Finished downloading. Output files are in "%s".', local_dir) + + +def download_directories( + local_dir_map: OrderedDict[str, str], bucket_name: str, gcs_path: str +) -> None: + """Download all the files from specific directories in a gcs bucket to local directories. + + Args: + local_dir_map: paths from the input locations to the local directories to store the downloaded files. It will + create the directory if it doesn't exist. + bucket_name: name of the gcs bucket. + gcs_path: the path to the gcs directory that stores the files. + """ + client = storage.Client() + blobs = client.list_blobs(bucket_name, prefix=gcs_path) + logging.info('Start downloading outputs from gcs "%s/%s"', bucket_name, gcs_path) + for blob in blobs: + file_name = basename(blob.name) + sub_dir = blob.name[len(gcs_path) + 1 : -len(file_name)] + # Determine local_dir based on what source dir it belongs to + local_dir = "" + for source_dir in local_dir_map.keys(): + local_target = local_dir_map[source_dir] + if Path("/" + sub_dir).is_relative_to(source_dir): + if local_target is not None and local_target and local_target.strip(): + local_dir = local_target + # Clean up sub_dir to no longer be an abs path so it will join properly for the output + if source_dir.startswith("/"): + source_dir = source_dir[1:] + if source_dir.endswith("/"): + source_dir = source_dir[:-1] + sub_dir = sub_dir[len(source_dir) + 1 :] + else: + local_dir = None + break + if local_dir is None: + logging.info( + 'Skipping downloading "%s" because no output directory was selected.', + file_name, + ) + continue + if not local_dir: + # The results files which should be output to the first valid output directory + for target_dir in local_dir_map.values(): + if target_dir is not None and target_dir and target_dir.strip(): + local_dir = target_dir + break + file_dir = join(local_dir, sub_dir) + os.makedirs(file_dir, exist_ok=True) + file_path = join(file_dir, file_name) + logging.info('Downloading output file to "%s"...', file_path) + blob.download_to_filename(file_path) + + logging.info('Finished downloading. Output files are in "%s".', local_dir_map)