From 83d45db946a512bbaabd094e9199d2df9d870bd0 Mon Sep 17 00:00:00 2001 From: aditya-balachander Date: Thu, 7 Nov 2024 11:23:37 +0530 Subject: [PATCH] Reference parent level record during similarity matching --- cumulusci/tasks/bulkdata/load.py | 117 +++++++++++++- cumulusci/tasks/bulkdata/mapping_parser.py | 18 ++- .../tasks/bulkdata/query_transformers.py | 60 ++++++++ cumulusci/tasks/bulkdata/select_utils.py | 87 ++++++----- cumulusci/tasks/bulkdata/step.py | 143 ++++++++++++++---- 5 files changed, 350 insertions(+), 75 deletions(-) diff --git a/cumulusci/tasks/bulkdata/load.py b/cumulusci/tasks/bulkdata/load.py index f83199050a..d4050c0aca 100644 --- a/cumulusci/tasks/bulkdata/load.py +++ b/cumulusci/tasks/bulkdata/load.py @@ -27,6 +27,7 @@ AddMappingFiltersToQuery, AddPersonAccountsToQuery, AddRecordTypesToQuery, + DynamicLookupQueryExtender, ) from cumulusci.tasks.bulkdata.step import ( DEFAULT_BULK_BATCH_SIZE, @@ -314,6 +315,7 @@ def configure_step(self, mapping): bulk_mode = mapping.bulk_mode or self.bulk_mode or "Parallel" api_options = {"batch_size": mapping.batch_size, "bulk_mode": bulk_mode} num_records_in_target = None + content_type = None fields = mapping.get_load_field_list() @@ -343,6 +345,8 @@ def configure_step(self, mapping): api_options["update_key"] = mapping.update_key[0] action = DataOperationType.UPSERT elif mapping.action == DataOperationType.SELECT: + # Set content type to json + content_type = "JSON" # Bulk process expects DataOpertionType to be QUERY action = DataOperationType.QUERY # Determine number of records in the target org @@ -354,6 +358,97 @@ def configure_step(self, mapping): for entry in record_count_response["sObjects"] } num_records_in_target = sobject_map.get(mapping.sf_object, None) + + # Check for similarity selection strategy and modify fields accordingly + if mapping.selection_strategy == "similarity": + # Describe the object to determine polymorphic lookups + describe_result = self.sf.restful( + f"sobjects/{mapping.sf_object}/describe" + ) + polymorphic_fields = { + field["name"]: field + for field in describe_result["fields"] + if field["type"] == "reference" + } + + # Loop through each lookup to get the corresponding fields + for name, lookup in mapping.lookups.items(): + if name in fields: + # Get the index of the lookup field before removing it + insert_index = fields.index(name) + # Remove the lookup field from fields + fields.remove(name) + + # Check if this lookup field is polymorphic + if ( + name in polymorphic_fields + and len(polymorphic_fields[name]["referenceTo"]) > 1 + ): + # Convert to list if string + if not isinstance(lookup.table, list): + lookup.table = [lookup.table] + # Polymorphic field handling + polymorphic_references = lookup.table + relationship_name = polymorphic_fields[name][ + "relationshipName" + ] + + # Loop through each polymorphic type (e.g., Contact, Lead) + for ref_type in polymorphic_references: + # Find the mapping step for this polymorphic type + lookup_mapping_step = next( + ( + step + for step in self.mapping.values() + if step.sf_object == ref_type + ), + None, + ) + + if lookup_mapping_step: + lookup_fields = ( + lookup_mapping_step.get_load_field_list() + ) + # Insert fields in the format {relationship_name}.{ref_type}.{lookup_field} + for field in lookup_fields: + fields.insert( + insert_index, + f"{relationship_name}.{lookup_mapping_step.sf_object}.{field}", + ) + insert_index += 1 + + else: + # Non-polymorphic field handling + lookup_table = lookup.table + + if isinstance(lookup_table, list): + 
lookup_table = lookup_table[0] + + # Get the mapping step for the non-polymorphic reference + lookup_mapping_step = next( + ( + step + for step in self.mapping.values() + if step.sf_object == lookup_table + ), + None, + ) + + if lookup_mapping_step: + relationship_name = polymorphic_fields[name][ + "relationshipName" + ] + lookup_fields = ( + lookup_mapping_step.get_load_field_list() + ) + + # Insert the new fields at the same position as the removed lookup field + for field in lookup_fields: + fields.insert( + insert_index, f"{relationship_name}.{field}" + ) + insert_index += 1 + else: action = mapping.action @@ -376,6 +471,7 @@ def configure_step(self, mapping): volume=volume, selection_strategy=mapping.selection_strategy, selection_filter=mapping.selection_filter, + content_type=content_type, ) return step, query @@ -406,6 +502,9 @@ def _stream_queried_data(self, mapping, local_ids, query): pkey = row[0] row = list(row[1:]) + statics + # Replace None values in row with empty strings + row = [value if value is not None else "" for value in row] + if mapping.anchor_date and (date_context[0] or date_context[1]): row = adjust_relative_dates( mapping, date_context, row, DataOperationType.INSERT @@ -475,9 +574,21 @@ def _query_db(self, mapping): AddMappingFiltersToQuery, AddUpsertsToQuery, ] - transformers = [ - AddLookupsToQuery(mapping, self.metadata, model, self._old_format) - ] + transformers = [] + if ( + mapping.action == DataOperationType.SELECT + and mapping.selection_strategy == "similarity" + ): + transformers.append( + DynamicLookupQueryExtender( + mapping, self.mapping, self.metadata, model, self._old_format + ) + ) + else: + transformers.append( + AddLookupsToQuery(mapping, self.metadata, model, self._old_format) + ) + transformers.extend([cls(mapping, self.metadata, model) for cls in classes]) if mapping.sf_object == "Contact" and self._can_load_person_accounts(mapping): diff --git a/cumulusci/tasks/bulkdata/mapping_parser.py b/cumulusci/tasks/bulkdata/mapping_parser.py index e812ca7d16..c9009f82fc 100644 --- a/cumulusci/tasks/bulkdata/mapping_parser.py +++ b/cumulusci/tasks/bulkdata/mapping_parser.py @@ -103,15 +103,15 @@ class MappingStep(CCIDictModel): batch_size: int = None oid_as_pk: bool = False # this one should be discussed and probably deprecated record_type: Optional[str] = None # should be discussed and probably deprecated - bulk_mode: Optional[ - Literal["Serial", "Parallel"] - ] = None # default should come from task options + bulk_mode: Optional[Literal["Serial", "Parallel"]] = ( + None # default should come from task options + ) anchor_date: Optional[Union[str, date]] = None soql_filter: Optional[str] = None # soql_filter property selection_strategy: SelectStrategy = SelectStrategy.STANDARD # selection strategy - selection_filter: Optional[ - str - ] = None # filter to be added at the end of select query + selection_filter: Optional[str] = ( + None # filter to be added at the end of select query + ) update_key: T.Union[str, T.Tuple[str, ...]] = () # only for upserts @validator("bulk_mode", "api", "action", "selection_strategy", pre=True) @@ -678,7 +678,9 @@ def _infer_and_validate_lookups(mapping: Dict, sf: Salesforce): if len(target_objects) == 1: # This is a non-polymorphic lookup. target_index = list(sf_objects.values()).index(target_objects[0]) - if target_index > idx or target_index == idx: + if ( + target_index > idx or target_index == idx + ) and m.action != DataOperationType.SELECT: # This is a non-polymorphic after step. 
lookup.after = list(mapping.keys())[idx] else: @@ -730,7 +732,7 @@ def validate_and_inject_mapping( if drop_missing: # Drop any steps with sObjects that are not present. - for (include, step_name) in zip(should_continue, list(mapping.keys())): + for include, step_name in zip(should_continue, list(mapping.keys())): if not include: del mapping[step_name] diff --git a/cumulusci/tasks/bulkdata/query_transformers.py b/cumulusci/tasks/bulkdata/query_transformers.py index aef23f5dc3..eda7a2cabe 100644 --- a/cumulusci/tasks/bulkdata/query_transformers.py +++ b/cumulusci/tasks/bulkdata/query_transformers.py @@ -86,6 +86,66 @@ def join_for_lookup(lookup): return [join_for_lookup(lookup) for lookup in self.lookups] +class DynamicLookupQueryExtender(LoadQueryExtender): + """Dynamically adds columns and joins for all fields in lookup tables, handling polymorphic lookups""" + + def __init__( + self, mapping, all_mappings, metadata, model, _old_format: bool + ) -> None: + super().__init__(mapping, metadata, model) + self._old_format = _old_format + self.all_mappings = all_mappings + self.lookups = [ + lookup for lookup in self.mapping.lookups.values() if not lookup.after + ] + + @cached_property + def columns_to_add(self): + """Add all relevant fields from lookup tables directly without CASE, with support for polymorphic lookups.""" + columns = [] + for lookup in self.lookups: + tables = lookup.table if isinstance(lookup.table, list) else [lookup.table] + lookup.aliased_table = [ + aliased(self.metadata.tables[table]) for table in tables + ] + + for aliased_table, table_name in zip(lookup.aliased_table, tables): + # Find the mapping step for this polymorphic type + lookup_mapping_step = next( + ( + step + for step in self.all_mappings.values() + if step.table == table_name + ), + None, + ) + if lookup_mapping_step: + load_fields = lookup_mapping_step.get_load_field_list() + for field in load_fields: + matching_column = next( + (col for col in aliased_table.columns if col.name == field) + ) + columns.append( + matching_column.label(f"{aliased_table.name}_{field}") + ) + return columns + + @cached_property + def outerjoins_to_add(self): + """Add outer joins for each lookup table directly, including handling for polymorphic lookups.""" + + def join_for_lookup(lookup, aliased_table): + key_field = lookup.get_lookup_key_field(self.model) + value_column = getattr(self.model, key_field) + return (aliased_table, aliased_table.columns.id == value_column) + + joins = [] + for lookup in self.lookups: + for aliased_table in lookup.aliased_table: + joins.append(join_for_lookup(lookup, aliased_table)) + return joins + + class AddRecordTypesToQuery(LoadQueryExtender): """Adds columns, joins and filters relatinng to recordtypes""" diff --git a/cumulusci/tasks/bulkdata/select_utils.py b/cumulusci/tasks/bulkdata/select_utils.py index 741ed17056..d1092504f4 100644 --- a/cumulusci/tasks/bulkdata/select_utils.py +++ b/cumulusci/tasks/bulkdata/select_utils.py @@ -142,14 +142,54 @@ def similarity_generate_query( limit: T.Union[int, None], offset: T.Union[int, None], ) -> T.Tuple[str, T.List[str]]: - """Generates the SOQL query for the similarity selection strategy""" - # Construct the query with the WHERE clause (if it exists) - if "Id" not in fields: - fields.insert(0, "Id") - fields_to_query = ", ".join(field for field in fields if field) - + """Generates the SOQL query for the similarity selection strategy, with support for TYPEOF on polymorphic fields.""" + + # Pre-process the new fields format to create a nested dict 
structure for TYPEOF clauses + nested_fields = {} + regular_fields = [] + + for field in fields: + components = field.split(".") + if len(components) >= 3: + # Handle polymorphic fields (format: {relationship_name}.{ref_obj}.{ref_field}) + relationship, ref_obj, ref_field = ( + components[0], + components[1], + components[2], + ) + if relationship not in nested_fields: + nested_fields[relationship] = {} + if ref_obj not in nested_fields[relationship]: + nested_fields[relationship][ref_obj] = [] + nested_fields[relationship][ref_obj].append(ref_field) + else: + # Handle regular fields (format: {field}) + regular_fields.append(field) + + # Construct the query fields + query_fields = [] + + # Build TYPEOF clauses for polymorphic fields + for relationship, references in nested_fields.items(): + type_clauses = [] + for ref_obj, ref_fields in references.items(): + fields_clause = ", ".join(ref_fields) + type_clauses.append(f"WHEN {ref_obj} THEN {fields_clause}") + type_clause = f"TYPEOF {relationship} {' '.join(type_clauses)} END" + query_fields.append(type_clause) + + # Add regular fields to the query + query_fields.extend(regular_fields) + + # Ensure "Id" is included in the fields list for identification + if "Id" not in query_fields: + query_fields.insert(0, "Id") + + # Build the main SOQL query + fields_to_query = ", ".join(query_fields) query = f"SELECT {fields_to_query} FROM {sobject}" + # Add the user-defined filter clause or default clause if user_filter: query += add_limit_offset_to_user_filter( filter_clause=user_filter, limit_clause=limit, offset_clause=offset @@ -161,7 +201,12 @@ def similarity_generate_query( query += f" WHERE {declaration.where}" query += f" LIMIT {limit}" if limit else "" query += f" OFFSET {offset}" if offset else "" - return query, fields + + # Return the original input fields with "Id" added if needed + if "Id" not in fields: + fields.insert(0, "Id") + + return query, fields # Return the original input fields with "Id" def similarity_post_process( @@ -178,8 +223,6 @@ def similarity_post_process( complexity_constant = load_record_count * query_record_count - print(complexity_constant) - closest_records = [] if complexity_constant < 1000: @@ -187,8 +230,6 @@ def similarity_post_process( else: closest_records = levenshtein_post_process(load_records, query_records) - print(closest_records) - return closest_records @@ -200,14 +241,6 @@ def annoy_post_process( query_records = replace_empty_strings_with_missing(query_records) load_records = replace_empty_strings_with_missing(load_records) - print("Query records: ") - print(query_records) - - print("Load records: ") - print(load_records) - - print("\n\n\n\n") - hash_features = 100 num_trees = 10 @@ -244,29 +277,15 @@ def annoy_post_process( load_vector, n_neighbors, include_distances=True ) neighbor_indices = nearest_neighbors[0] # Indices of nearest neighbors - distances = nearest_neighbors[1] # Distances to nearest neighbors - load_record = load_records[i] # Get the query record for the current index - print(f"Load record {i + 1}: {load_record}\n") # Print the query record - - # Print the nearest neighbors for the current query - print(f"Nearest neighbors for load record {i + 1}:") - - for j, neighbor_index in enumerate(neighbor_indices): + for neighbor_index in neighbor_indices: # Retrieve the corresponding record from the database record = query_record_data[neighbor_index] - distance = distances[j] - - # Print the record and its distance - print(f" Neighbor {j + 1}: {record}, Distance: {distance:.6f}") 
closest_record_id = record_to_id_map[tuple(record)] - print("Record id:" + closest_record_id) closest_records.append( {"id": closest_record_id, "success": True, "created": False} ) - print("\n") # Add a newline for better readability between query results - return closest_records, None diff --git a/cumulusci/tasks/bulkdata/step.py b/cumulusci/tasks/bulkdata/step.py index 3f3fbaf0f3..b664b48ffc 100644 --- a/cumulusci/tasks/bulkdata/step.py +++ b/cumulusci/tasks/bulkdata/step.py @@ -352,6 +352,7 @@ def __init__( fields, selection_strategy=SelectStrategy.STANDARD, selection_filter=None, + content_type=None, ): super().__init__( sobject=sobject, @@ -369,12 +370,13 @@ def __init__( self.select_operation_executor = SelectOperationExecutor(selection_strategy) self.selection_filter = selection_filter + self.content_type = content_type if content_type else "CSV" def start(self): self.job_id = self.bulk.create_job( self.sobject, self.operation.value, - contentType="CSV", + contentType=self.content_type, concurrency=self.api_options.get("bulk_mode", "Parallel"), external_id_name=self.api_options.get("update_key"), ) @@ -498,31 +500,39 @@ def select_records(self, records): # Update job result based on selection outcome self.job_result = DataOperationJobResult( - status=DataOperationStatus.SUCCESS - if len(self.select_results) - else DataOperationStatus.JOB_FAILURE, + status=( + DataOperationStatus.SUCCESS + if len(self.select_results) + else DataOperationStatus.JOB_FAILURE + ), job_errors=[error_message] if error_message else [], records_processed=len(self.select_results), total_row_errors=0, ) def _execute_select_query(self, select_query: str, query_fields: List[str]): - """Executes the select Bulk API query and retrieves the results.""" + """Executes the select Bulk API query, retrieves results in JSON, and converts to CSV format if needed.""" self.batch_id = self.bulk.query(self.job_id, select_query) - self._wait_for_job(self.job_id) + self.bulk.wait_for_batch(self.job_id, self.batch_id) result_ids = self.bulk.get_query_batch_result_ids( self.batch_id, job_id=self.job_id ) select_query_records = [] + for result_id in result_ids: - uri = f"{self.bulk.endpoint}/job/{self.job_id}/batch/{self.batch_id}/result/{result_id}" + # Modify URI to request JSON format + uri = f"{self.bulk.endpoint}/job/{self.job_id}/batch/{self.batch_id}/result/{result_id}?format=json" + # Download JSON data with download_file(uri, self.bulk) as f: - reader = csv.reader(f) - self.headers = next(reader) - if "Records not found for this query" in self.headers: - break - for row in reader: - select_query_records.append(row[: len(query_fields)]) + data = json.load(f) + # Get headers from fields, expanding nested structures for TYPEOF results + self.headers = query_fields + + # Convert each record to a flat row + for record in data: + flat_record = flatten_record(record, self.headers) + select_query_records.append(flat_record) + return select_query_records def _batch(self, records, n, char_limit=10000000): @@ -641,6 +651,7 @@ def __init__( fields, selection_strategy=SelectStrategy.STANDARD, selection_filter=None, + content_type=None, ): super().__init__( sobject=sobject, @@ -655,7 +666,9 @@ def __init__( field["name"]: field for field in getattr(context.sf, sobject).describe()["fields"] } - self.boolean_fields = [f for f in fields if describe[f]["type"] == "boolean"] + self.boolean_fields = [ + f for f in fields if "." 
not in f and describe[f]["type"] == "boolean" + ] self.api_options = api_options.copy() self.api_options["batch_size"] = ( self.api_options.get("batch_size") or DEFAULT_REST_BATCH_SIZE @@ -666,6 +679,7 @@ def __init__( self.select_operation_executor = SelectOperationExecutor(selection_strategy) self.selection_filter = selection_filter + self.content_type = content_type def _record_to_json(self, rec): result = dict(zip(self.fields, rec)) @@ -764,9 +778,11 @@ def load_records(self, records): row_errors = len([res for res in self.results if not res["success"]]) self.job_result = DataOperationJobResult( - DataOperationStatus.SUCCESS - if not row_errors - else DataOperationStatus.ROW_FAILURE, + ( + DataOperationStatus.SUCCESS + if not row_errors + else DataOperationStatus.ROW_FAILURE + ), [], len(self.results), row_errors, @@ -775,10 +791,6 @@ def load_records(self, records): def select_records(self, records): """Executes a SOQL query to select records and adds them to results""" - def convert(rec, fields): - """Helper function to convert record values to strings, handling None values""" - return [str(rec[f]) if rec[f] is not None else "" for f in fields] - self.results = [] query_records = [] # Create a copy of the generator using tee @@ -814,17 +826,18 @@ def convert(rec, fields): response = self.sf.restful( requests.utils.requote_uri(f"query/?q={select_query}"), method="GET" ) - query_records.extend( - list(convert(rec, query_fields) for rec in response["records"]) - ) + # Convert each record to a flat row + for record in response["records"]: + flat_record = flatten_record(record, query_fields) + query_records.append(flat_record) while True: if not response["done"]: response = self.sf.query_more( response["nextRecordsUrl"], identifier_is_url=True ) - query_records.extend( - list(convert(rec, query_fields) for rec in response["records"]) - ) + for record in response["records"]: + flat_record = flatten_record(record, query_fields) + query_records.append(flat_record) else: break @@ -844,9 +857,11 @@ def convert(rec, fields): # Update the job result based on the overall selection outcome self.job_result = DataOperationJobResult( - status=DataOperationStatus.SUCCESS - if len(self.results) # Check the overall results length - else DataOperationStatus.JOB_FAILURE, + status=( + DataOperationStatus.SUCCESS + if len(self.results) # Check the overall results length + else DataOperationStatus.JOB_FAILURE + ), job_errors=[error_message] if error_message else [], records_processed=len(self.results), total_row_errors=0, @@ -988,6 +1003,7 @@ def get_dml_operation( api: Optional[DataApi] = DataApi.SMART, selection_strategy: SelectStrategy = SelectStrategy.STANDARD, selection_filter: Union[str, None] = None, + content_type: Union[str, None] = None, ) -> BaseDmlOperation: """Create an appropriate DmlOperation instance for the given parameters, selecting between REST and Bulk APIs based upon volume (Bulk used at volumes over 2000 records, @@ -1023,4 +1039,71 @@ def get_dml_operation( fields=fields, selection_strategy=selection_strategy, selection_filter=selection_filter, + content_type=content_type, ) + + +def extract_flattened_headers(query_fields): + """Extract headers from query fields, including handling of TYPEOF fields.""" + headers = [] + + for field in query_fields: + if isinstance(field, dict): + # Handle TYPEOF / polymorphic fields + for lookup, references in field.items(): + # Assuming each reference is a list of dictionaries + for ref_type in references: + for ref_obj, ref_fields in 
ref_type.items(): + for nested_field in ref_fields: + headers.append( + f"{lookup}.{ref_obj}.{nested_field}" + ) # Flatten the structure + else: + # Regular fields + headers.append(field) + + return headers + + +def flatten_record(record, headers): + """Flatten each record to match headers, handling nested fields.""" + flat_record = [] + + for field in headers: + components = field.split(".") + value = "" + + # Handle lookup fields with two or three components + if len(components) >= 2: + lookup_field = components[0] + lookup = record.get(lookup_field, None) + + # Check if lookup field exists in the record + if lookup is None: + value = "" + else: + if len(components) == 2: + # Handle fields with two components: {lookup}.{ref_field} + ref_field = components[1] + value = lookup.get(ref_field, "") + elif len(components) == 3: + # Handle fields with three components: {lookup}.{ref_obj}.{ref_field} + ref_obj, ref_field = components[1], components[2] + # Check if the type matches the specified ref_obj + if lookup.get("attributes", {}).get("type") == ref_obj: + value = lookup.get(ref_field, "") + else: + value = "" + + else: + # Regular fields or non-polymorphic fields + value = record.get(field, "") + + # Set None values to empty string + if value is None: + value = "" + + # Append the resolved value to the flattened record + flat_record.append(value) + + return flat_record
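
Below is a minimal usage sketch of the new record-flattening helper (this assumes the patch above is applied on top of CumulusCI; the Task/Who object and field names are illustrative only, not taken from a real mapping). It shows how a JSON query result that used a TYPEOF clause on a polymorphic lookup is reduced to the flat row format the similarity matcher consumes:

    from cumulusci.tasks.bulkdata.step import flatten_record

    # One record as returned in the JSON query results. The query built by
    # similarity_generate_query for these headers would look roughly like:
    #   SELECT Id, TYPEOF Who WHEN Contact THEN LastName WHEN Lead THEN Company END,
    #          Subject FROM Task
    record = {
        "Id": "00TB00000123ABC",
        "Subject": "Follow up",
        "Who": {
            "attributes": {"type": "Contact"},
            "LastName": "Rivera",
        },
    }

    # Flattened headers in the {relationship}.{ref_object}.{ref_field} format
    # that configure_step now builds for polymorphic lookups.
    headers = ["Id", "Subject", "Who.Contact.LastName", "Who.Lead.Company"]

    print(flatten_record(record, headers))
    # ['00TB00000123ABC', 'Follow up', 'Rivera', '']
    # Who.Lead.Company is blank because the returned Who record is a Contact.

The unmatched branch of the TYPEOF clause flattens to an empty string, so downstream similarity scoring always sees a consistent column layout regardless of which object type each polymorphic lookup resolves to.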