Skip to content

Commit

Permalink
Reference parent level record during similarity matching
Browse files Browse the repository at this point in the history
  • Loading branch information
aditya-balachander committed Nov 7, 2024
1 parent 40097c1 commit 83d45db
Show file tree
Hide file tree
Showing 5 changed files with 350 additions and 75 deletions.
117 changes: 114 additions & 3 deletions cumulusci/tasks/bulkdata/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
AddMappingFiltersToQuery,
AddPersonAccountsToQuery,
AddRecordTypesToQuery,
DynamicLookupQueryExtender,
)
from cumulusci.tasks.bulkdata.step import (
DEFAULT_BULK_BATCH_SIZE,
Expand Down Expand Up @@ -314,6 +315,7 @@ def configure_step(self, mapping):
bulk_mode = mapping.bulk_mode or self.bulk_mode or "Parallel"
api_options = {"batch_size": mapping.batch_size, "bulk_mode": bulk_mode}
num_records_in_target = None
content_type = None

fields = mapping.get_load_field_list()

Expand Down Expand Up @@ -343,6 +345,8 @@ def configure_step(self, mapping):
api_options["update_key"] = mapping.update_key[0]
action = DataOperationType.UPSERT
elif mapping.action == DataOperationType.SELECT:
# Set content type to json
content_type = "JSON"
# Bulk process expects DataOpertionType to be QUERY
action = DataOperationType.QUERY
# Determine number of records in the target org
Expand All @@ -354,6 +358,97 @@ def configure_step(self, mapping):
for entry in record_count_response["sObjects"]
}
num_records_in_target = sobject_map.get(mapping.sf_object, None)

# Check for similarity selection strategy and modify fields accordingly
if mapping.selection_strategy == "similarity":
# Describe the object to determine polymorphic lookups
describe_result = self.sf.restful(
f"sobjects/{mapping.sf_object}/describe"
)
polymorphic_fields = {
field["name"]: field
for field in describe_result["fields"]
if field["type"] == "reference"
}

# Loop through each lookup to get the corresponding fields
for name, lookup in mapping.lookups.items():
if name in fields:
# Get the index of the lookup field before removing it
insert_index = fields.index(name)
# Remove the lookup field from fields
fields.remove(name)

# Check if this lookup field is polymorphic
if (
name in polymorphic_fields
and len(polymorphic_fields[name]["referenceTo"]) > 1
):
# Convert to list if string
if not isinstance(lookup.table, list):
lookup.table = [lookup.table]
# Polymorphic field handling
polymorphic_references = lookup.table
relationship_name = polymorphic_fields[name][
"relationshipName"
]

# Loop through each polymorphic type (e.g., Contact, Lead)
for ref_type in polymorphic_references:
# Find the mapping step for this polymorphic type
lookup_mapping_step = next(
(
step
for step in self.mapping.values()
if step.sf_object == ref_type
),
None,
)

if lookup_mapping_step:
lookup_fields = (
lookup_mapping_step.get_load_field_list()
)
# Insert fields in the format {relationship_name}.{ref_type}.{lookup_field}
for field in lookup_fields:
fields.insert(
insert_index,
f"{relationship_name}.{lookup_mapping_step.sf_object}.{field}",
)
insert_index += 1

else:
# Non-polymorphic field handling
lookup_table = lookup.table

if isinstance(lookup_table, list):
lookup_table = lookup_table[0]

# Get the mapping step for the non-polymorphic reference
lookup_mapping_step = next(
(
step
for step in self.mapping.values()
if step.sf_object == lookup_table
),
None,
)

if lookup_mapping_step:
relationship_name = polymorphic_fields[name][
"relationshipName"
]
lookup_fields = (
lookup_mapping_step.get_load_field_list()
)

# Insert the new fields at the same position as the removed lookup field
for field in lookup_fields:
fields.insert(
insert_index, f"{relationship_name}.{field}"
)
insert_index += 1

else:
action = mapping.action

Expand All @@ -376,6 +471,7 @@ def configure_step(self, mapping):
volume=volume,
selection_strategy=mapping.selection_strategy,
selection_filter=mapping.selection_filter,
content_type=content_type,
)
return step, query

Expand Down Expand Up @@ -406,6 +502,9 @@ def _stream_queried_data(self, mapping, local_ids, query):
pkey = row[0]
row = list(row[1:]) + statics

# Replace None values in row with empty strings
row = [value if value is not None else "" for value in row]

if mapping.anchor_date and (date_context[0] or date_context[1]):
row = adjust_relative_dates(
mapping, date_context, row, DataOperationType.INSERT
Expand Down Expand Up @@ -475,9 +574,21 @@ def _query_db(self, mapping):
AddMappingFiltersToQuery,
AddUpsertsToQuery,
]
transformers = [
AddLookupsToQuery(mapping, self.metadata, model, self._old_format)
]
transformers = []
if (
mapping.action == DataOperationType.SELECT
and mapping.selection_strategy == "similarity"
):
transformers.append(
DynamicLookupQueryExtender(
mapping, self.mapping, self.metadata, model, self._old_format
)
)
else:
transformers.append(
AddLookupsToQuery(mapping, self.metadata, model, self._old_format)
)

transformers.extend([cls(mapping, self.metadata, model) for cls in classes])

if mapping.sf_object == "Contact" and self._can_load_person_accounts(mapping):
Expand Down
18 changes: 10 additions & 8 deletions cumulusci/tasks/bulkdata/mapping_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@ class MappingStep(CCIDictModel):
batch_size: int = None
oid_as_pk: bool = False # this one should be discussed and probably deprecated
record_type: Optional[str] = None # should be discussed and probably deprecated
bulk_mode: Optional[
Literal["Serial", "Parallel"]
] = None # default should come from task options
bulk_mode: Optional[Literal["Serial", "Parallel"]] = (
None # default should come from task options
)
anchor_date: Optional[Union[str, date]] = None
soql_filter: Optional[str] = None # soql_filter property
selection_strategy: SelectStrategy = SelectStrategy.STANDARD # selection strategy
selection_filter: Optional[
str
] = None # filter to be added at the end of select query
selection_filter: Optional[str] = (
None # filter to be added at the end of select query
)
update_key: T.Union[str, T.Tuple[str, ...]] = () # only for upserts

@validator("bulk_mode", "api", "action", "selection_strategy", pre=True)
Expand Down Expand Up @@ -678,7 +678,9 @@ def _infer_and_validate_lookups(mapping: Dict, sf: Salesforce):
if len(target_objects) == 1:
# This is a non-polymorphic lookup.
target_index = list(sf_objects.values()).index(target_objects[0])
if target_index > idx or target_index == idx:
if (
target_index > idx or target_index == idx
) and m.action != DataOperationType.SELECT:
# This is a non-polymorphic after step.
lookup.after = list(mapping.keys())[idx]
else:
Expand Down Expand Up @@ -730,7 +732,7 @@ def validate_and_inject_mapping(

if drop_missing:
# Drop any steps with sObjects that are not present.
for (include, step_name) in zip(should_continue, list(mapping.keys())):
for include, step_name in zip(should_continue, list(mapping.keys())):
if not include:
del mapping[step_name]

Expand Down
60 changes: 60 additions & 0 deletions cumulusci/tasks/bulkdata/query_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,66 @@ def join_for_lookup(lookup):
return [join_for_lookup(lookup) for lookup in self.lookups]


class DynamicLookupQueryExtender(LoadQueryExtender):
    """Adds columns and outer joins for every load field of each lookup's
    target table(s), supporting both single-target and polymorphic lookups.

    Used by the "similarity" selection strategy so that the query returns the
    parent record's own field values (e.g. ``ParentTable_Name``) rather than
    only the lookup id, allowing parent-level data to participate in matching.
    """

    def __init__(
        self, mapping, all_mappings, metadata, model, _old_format: bool
    ) -> None:
        """Store the full mapping so lookup targets can be resolved.

        Args:
            mapping: the mapping step being queried.
            all_mappings: dict of ALL mapping steps, keyed by step name;
                needed to find the step that loads each lookup's target table.
            metadata: SQLAlchemy metadata holding the local tables.
            model: ORM model for the mapping step's own table.
            _old_format: legacy id-format flag, kept for parity with
                AddLookupsToQuery.
        """
        super().__init__(mapping, metadata, model)
        self._old_format = _old_format
        self.all_mappings = all_mappings
        # "after" lookups are resolved in a later pass; exclude them here.
        self.lookups = [
            lookup for lookup in self.mapping.lookups.values() if not lookup.after
        ]

    @cached_property
    def columns_to_add(self):
        """Return labeled columns for all load fields of each lookup target.

        Labels are ``{aliased_table_name}_{field}``. Also caches the aliased
        table objects on each lookup (``lookup.aliased_table``) for reuse by
        ``outerjoins_to_add``.

        Raises:
            KeyError: if a mapped load field has no matching column in the
                lookup's target table.
        """
        columns = []
        for lookup in self.lookups:
            # A polymorphic lookup lists several target tables; normalize the
            # single-table case to a one-element list.
            tables = lookup.table if isinstance(lookup.table, list) else [lookup.table]
            # Alias each target table so the same table can be joined more
            # than once; outerjoins_to_add reuses these aliases.
            lookup.aliased_table = [
                aliased(self.metadata.tables[table]) for table in tables
            ]

            for aliased_table, table_name in zip(lookup.aliased_table, tables):
                # Find the mapping step that loads this target table.
                lookup_mapping_step = next(
                    (
                        step
                        for step in self.all_mappings.values()
                        if step.table == table_name
                    ),
                    None,
                )
                if lookup_mapping_step is None:
                    # No step loads this table; nothing to select from it.
                    continue

                for field in lookup_mapping_step.get_load_field_list():
                    # Fix: the original called next() without a default, so a
                    # missing column surfaced as a bare StopIteration; raise
                    # a descriptive error instead.
                    matching_column = next(
                        (
                            col
                            for col in aliased_table.columns
                            if col.name == field
                        ),
                        None,
                    )
                    if matching_column is None:
                        raise KeyError(
                            f"Load field '{field}' for mapping step of "
                            f"'{table_name}' has no matching column in table "
                            f"'{aliased_table.name}'"
                        )
                    columns.append(
                        matching_column.label(f"{aliased_table.name}_{field}")
                    )
        return columns

    @cached_property
    def outerjoins_to_add(self):
        """Return (aliased_table, ON-clause) pairs joining each lookup target.

        NOTE(review): relies on ``columns_to_add`` having been evaluated
        first, since that property populates ``lookup.aliased_table`` —
        confirm the base LoadQueryExtender always adds columns before joins.
        """

        def join_for_lookup(lookup, aliased_table):
            # Join the target table's id to this table's lookup key column.
            key_field = lookup.get_lookup_key_field(self.model)
            value_column = getattr(self.model, key_field)
            return (aliased_table, aliased_table.columns.id == value_column)

        joins = []
        for lookup in self.lookups:
            for aliased_table in lookup.aliased_table:
                joins.append(join_for_lookup(lookup, aliased_table))
        return joins


class AddRecordTypesToQuery(LoadQueryExtender):
"""Adds columns, joins and filters relatinng to recordtypes"""

Expand Down
Loading

0 comments on commit 83d45db

Please sign in to comment.