
Commit be98883

Merge pull request #1772 from SFDO-Tooling/feature/smart_get_lookup_key_field

Feature/smart get lookup key field
David Glick authored Jun 4, 2020
2 parents 646888d + 74d3f80 commit be98883
Showing 10 changed files with 219 additions and 91 deletions.
5 changes: 2 additions & 3 deletions cumulusci/tasks/bulkdata/base_generate_data_task.py
@@ -5,10 +5,10 @@
from sqlalchemy import MetaData
from sqlalchemy.orm import create_session
from sqlalchemy.ext.automap import automap_base
import yaml

from cumulusci.core.tasks import BaseTask
from cumulusci.core.exceptions import TaskOptionsError
from cumulusci.tasks.bulkdata.mapping_parser import parse_from_yaml

from .utils import create_table

@@ -73,8 +73,7 @@ def _read_mappings(self, mapping_file_path):
if not mapping_file_path:
raise TaskOptionsError("Mapping file path required")

with open(mapping_file_path, "r") as f:
return yaml.safe_load(f)
return parse_from_yaml(mapping_file_path)

@staticmethod
def init_db(db_url, mappings):
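For context: yaml.safe_load returns plain nested dicts, while parse_from_yaml validates the mapping against the pydantic models added in mapping_parser.py (below) and returns model objects. A minimal sketch of the difference, assuming a hypothetical mapping file mapping.yml:

# Hedged sketch of old vs. new mapping loading; "mapping.yml" is a
# hypothetical example path, not a file from this repository.
import yaml
from cumulusci.tasks.bulkdata.mapping_parser import parse_from_yaml

with open("mapping.yml", "r") as f:
    raw_mapping = yaml.safe_load(f)  # plain dicts, no validation

validated_mapping = parse_from_yaml("mapping.yml")  # validated models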
35 changes: 18 additions & 17 deletions cumulusci/tasks/bulkdata/extract.py
@@ -9,18 +9,16 @@
from sqlalchemy.ext.automap import automap_base
import tempfile

import yaml

from cumulusci.core.exceptions import TaskOptionsError, BulkDataException
from cumulusci.tasks.bulkdata.utils import (
SqlAlchemyMixin,
get_lookup_key_field,
create_table,
fields_for_mapping,
)
from cumulusci.tasks.salesforce import BaseSalesforceApiTask
from cumulusci.tasks.bulkdata.step import BulkApiQueryOperation, DataOperationStatus
from cumulusci.utils import os_friendly_path, log_progress
from cumulusci.tasks.bulkdata.mapping_parser import parse_from_yaml


class ExtractData(SqlAlchemyMixin, BaseSalesforceApiTask):
@@ -58,7 +56,7 @@ def _run_task(self):
self._init_mapping()
self._init_db()

for mapping in self.mappings.values():
for mapping in self.mapping.values():
soql = self._soql_for_mapping(mapping)
self._run_query(soql, mapping)

@@ -90,8 +88,11 @@ def _init_db(self):

def _init_mapping(self):
"""Load a YAML mapping file."""
with open(self.options["mapping"], "r") as f:
self.mappings = yaml.safe_load(f)
mapping_file_path = self.options["mapping"]
if not mapping_file_path:
raise TaskOptionsError("Mapping file path required")

self.mapping = parse_from_yaml(mapping_file_path)

def _fields_for_mapping(self, mapping):
"""Return a flat list of fields for this mapping."""
@@ -138,13 +139,13 @@ def _import_results(self, mapping, step):
fields = self._fields_for_mapping(mapping)
columns = []
lookup_keys = []
for sf in fields:
column = mapping.get("fields", {}).get(sf)
for field_name in fields:
column = mapping.get("fields", {}).get(field_name)
if not column:
lookup = mapping.get("lookups", {}).get(sf, {})
lookup = mapping.get("lookups", {}).get(field_name, {})
if lookup:
lookup_keys.append(sf)
column = get_lookup_key_field(lookup, sf)
lookup_keys.append(field_name)
column = lookup.get_lookup_key_field()
if column:
columns.append(column)
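
To illustrate the resolution order in the loop above, a self-contained sketch with invented field and column names (the real values come from the parsed mapping):

# Illustrative stand-in for the loop above; all names are invented.
fields = ["Name", "AccountId"]
field_columns = {"Name": "name"}  # mapping["fields"]
lookup_names = {"AccountId"}      # keys of mapping["lookups"]

columns, lookup_keys = [], []
for field_name in fields:
    column = field_columns.get(field_name)
    if not column and field_name in lookup_names:
        lookup_keys.append(field_name)
        # get_lookup_key_field() defaults to the lookup's own name
        # when no explicit key_field is configured
        column = field_name
    if column:
        columns.append(column)

assert columns == ["name", "AccountId"]
assert lookup_keys == ["AccountId"]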

@@ -200,7 +201,7 @@ def _import_results(self, mapping, step):

def _get_mapping_for_table(self, table):
"""Return the first mapping for a table name """
for mapping in self.mappings.values():
for mapping in self.mapping.values():
if mapping["table"] == table:
return mapping

@@ -219,11 +220,11 @@ def _split_batch_csv(self, records, f_values, f_ids):
def _convert_lookups_to_id(self, mapping, lookup_keys):
"""Rewrite persisted Salesforce Ids to refer to auto-PKs."""
for lookup_key in lookup_keys:
lookup_dict = mapping["lookups"][lookup_key]
lookup_info = mapping["lookups"][lookup_key]
model = self.models[mapping["table"]]
lookup_mapping = self._get_mapping_for_table(lookup_dict["table"])
lookup_mapping = self._get_mapping_for_table(lookup_info["table"])
lookup_model = self.models[lookup_mapping["sf_id_table"]]
key_field = get_lookup_key_field(lookup_dict, lookup_key)
key_field = lookup_info.get_lookup_key_field()
key_attr = getattr(model, key_field)
try:
self.session.query(model).filter(
@@ -241,7 +242,7 @@ def _convert_lookups_to_id(self, mapping, lookup_keys):

def _create_tables(self):
"""Create a table for each mapping step."""
for mapping in self.mappings.values():
for mapping in self.mapping.values():
self._create_table(mapping)
self.metadata.create_all()

@@ -279,7 +280,7 @@ def _create_table(self, mapping):

def _drop_sf_id_columns(self):
"""Drop Salesforce Id storage tables after rewriting Ids to auto-PKs."""
for mapping in self.mappings.values():
for mapping in self.mapping.values():
if mapping.get("oid_as_pk"):
continue
self.metadata.tables[mapping["sf_id_table"]].drop()
8 changes: 4 additions & 4 deletions cumulusci/tasks/bulkdata/generate_from_yaml.py
@@ -10,6 +10,7 @@

from cumulusci.core.exceptions import TaskOptionsError
from cumulusci.tasks.bulkdata.base_generate_data_task import BaseGenerateDataTask
from cumulusci.tasks.bulkdata.mapping_parser import parse_from_yaml
from snowfakery.output_streams import SqlOutputStream
from snowfakery.data_generator import generate, StoppingCriteria
from snowfakery.generate_mapping_from_factory import mapping_from_factory_templates
@@ -83,10 +84,9 @@ def _init_options(self, kwargs):
def _generate_data(self, db_url, mapping_file_path, num_records, current_batch_num):
"""Generate all of the data"""
if mapping_file_path:
with open(mapping_file_path, "r") as f:
self.mappings = yaml.safe_load(f)
self.mapping = parse_from_yaml(mapping_file_path)
else:
self.mappings = {}
self.mapping = {}
self.logger.info(f"Generating batch {current_batch_num} with {num_records}")
self.generate_data(db_url, num_records, current_batch_num)

@@ -135,7 +135,7 @@ def open_new_continuation_file(self) -> Optional[TextIO]:
return new_continuation_file

def generate_data(self, db_url, num_records, current_batch_num):
output_stream = SqlOutputStream.from_url(db_url, self.mappings)
output_stream = SqlOutputStream.from_url(db_url, self.mapping)
old_continuation_file = self.get_old_continuation_file()
if old_continuation_file:
# reopen to ensure file pointer is at starting point
11 changes: 6 additions & 5 deletions cumulusci/tasks/bulkdata/load.py
@@ -10,7 +10,6 @@
from cumulusci.core.exceptions import BulkDataException, TaskOptionsError
from cumulusci.core.utils import process_bool_arg
from cumulusci.tasks.bulkdata.utils import (
get_lookup_key_field,
SqlAlchemyMixin,
RowErrorChecker,
)
@@ -292,7 +291,7 @@ def _query_db(self, mapping):
for sf_field, lookup in lookups.items():
# Outer join with lookup ids table:
# returns main obj even if lookup is null
key_field = get_lookup_key_field(lookup, sf_field)
key_field = lookup.get_lookup_key_field(model)
value_column = getattr(model, key_field)
query = query.outerjoin(
lookup["aliased_table"],
@@ -431,9 +430,11 @@ def _init_db(self):

def _init_mapping(self):
"""Load a YAML mapping file."""
with open(self.options["mapping"], "r") as f:
# yaml.safe_load should also work here for now.
self.mapping = parse_from_yaml(f)
mapping_file_path = self.options["mapping"]
if not mapping_file_path:
raise TaskOptionsError("Mapping file path required")

self.mapping = parse_from_yaml(mapping_file_path)

def _expand_mapping(self):
"""Walk the mapping and generate any required 'after' steps
36 changes: 35 additions & 1 deletion cumulusci/tasks/bulkdata/mapping_parser.py
@@ -2,9 +2,11 @@
from logging import getLogger
from pathlib import Path

from pydantic import Field, validator, ValidationError
from pydantic import Field, validator, root_validator, ValidationError

from cumulusci.utils.yaml.model_parser import CCIDictModel
from cumulusci.utils import convert_to_snake_case

from typing_extensions import Literal

LOGGER_NAME = "MAPPING_LOADER"
@@ -19,6 +21,30 @@ class MappingLookup(CCIDictModel):
join_field: Optional[str] = None
after: Optional[str] = None
aliased_table: Optional[str] = None
name: Optional[str] = None # populated by parent

def get_lookup_key_field(self, model=None):
"Find the field name for this lookup."
guesses = []
if self.get("key_field"):
guesses.append(self.get("key_field"))

guesses.append(self.name)

if not model:
return guesses[0]

# CCI used snake_case until mid-2020.
# At some point this code could probably be simplified.
snake_cased_guesses = list(map(convert_to_snake_case, guesses))
guesses = guesses + snake_cased_guesses
for guess in guesses:
if hasattr(model, guess):
return guess
raise KeyError(
f"Could not find a key field for {self.name}.\n"
+ f"Tried {', '.join(guesses)}"
)
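
A standalone sketch of this resolution order; the snake_case helper and DummyModel are invented stand-ins for convert_to_snake_case and the SQLAlchemy model:

# Standalone sketch of the lookup-key resolution order implemented above.
import re

def snake_case(name):
    # minimal stand-in for cumulusci.utils.convert_to_snake_case
    return re.sub(r"(?<=[a-z0-9])([A-Z])", r"_\1", name).lower()

class DummyModel:
    account_id = None  # legacy snake_case column name

def find_key_field(key_field, name, model=None):
    guesses = [g for g in (key_field, name) if g]
    if model is None:
        return guesses[0]  # no model to check against: first guess wins
    # try each guess, then its snake_cased form, against the model
    for guess in guesses + [snake_case(g) for g in guesses]:
        if hasattr(model, guess):
            return guess
    raise KeyError(f"Could not find a key field; tried {', '.join(guesses)}")

print(find_key_field(None, "AccountId"))              # AccountId
print(find_key_field(None, "AccountId", DummyModel))  # account_id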


class MappingStep(CCIDictModel):
@@ -35,6 +61,7 @@ class MappingStep(CCIDictModel):
bulk_mode: Optional[
Literal["Serial", "Parallel"]
] = None # default should come from task options
sf_id_table: Optional[str] = None # populated at runtime in extract.py

@validator("record_type")
def record_type_is_deprecated(cls, v):
@@ -50,6 +77,13 @@ def oid_as_pk_is_deprecated(cls, v):
)
return v

@root_validator # not really a validator, more like a post-processor
def fixup_lookup_names(cls, v):
"Allow lookup objects to know the key they were attached to in the mapping file."
for name, lookup in v["lookups"].items():
lookup.name = name
return v
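
A hedged, self-contained illustration of the same pattern, assuming pydantic v1 (current when this was written); Lookup and Step here are invented example models, not CumulusCI's:

# Self-contained sketch of the root_validator name-fixup pattern.
from typing import Dict, Optional
from pydantic import BaseModel, root_validator

class Lookup(BaseModel):
    table: str
    name: Optional[str] = None  # filled in by the parent step

class Step(BaseModel):
    lookups: Dict[str, Lookup] = {}

    @root_validator
    def fixup_lookup_names(cls, values):
        # copy each mapping-file key onto the lookup it names
        for name, lookup in values["lookups"].items():
            lookup.name = name
        return values

step = Step.parse_obj({"lookups": {"AccountId": {"table": "Account"}}})
assert step.lookups["AccountId"].name == "AccountId"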


class MappingSteps(CCIDictModel):
"Mapping of named steps"
(Diffs for the remaining 5 changed files are not shown.)
