fix: eager source row fetching logic (GoogleCloudPlatform#2071)
This is an old bug which could only surface with the more recent
addition of custom Cypher queries.

The template tries to pre-fetch source data. It does so for text
sources, since they do not support SQL pushdown and their data is
required for post-processing anyway.

Before this commit, the template also pre-fetched data when none of
the source's targets defined source transformations. That condition
is both overly restrictive and actually wrong.

Custom query targets cannot define source transformations. If such a
target shares a source with node or relationship targets that do
define source transformations, no pre-fetch would happen, and the
template would crash with a NullPointerException when the custom
query target tried to consume the missing rows.

This is now fixed: source data is pre-fetched as long as at least one
of the source's targets does not define any transformations.

The commit also adds another small optimization: if a source does not
match any active target, its processing is skipped entirely. Before
this change, such a source's data could still be pre-fetched,
incurring unnecessary data movement.
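To make the fixed condition concrete, here is a minimal standalone sketch of the old and new pre-fetch predicates. The Target record and rule methods are hypothetical stand-ins, not the template's actual classes (the real logic lives in the run() method shown in the diff below); the target names are borrowed from the new test spec at the end of this commit.

import java.util.List;

public class PrefetchDecisionSketch {

  // Hypothetical stand-in for an import-spec target.
  record Target(String name, boolean hasTransforms) {}

  // Old rule: pre-fetch only when NO target of the source defines transformations.
  static boolean oldRule(boolean supportsSqlPushDown, List<Target> activeSourceTargets) {
    boolean anyTransforms = activeSourceTargets.stream().anyMatch(Target::hasTransforms);
    return !anyTransforms || !supportsSqlPushDown;
  }

  // New rule: pre-fetch as soon as AT LEAST ONE target defines no transformations.
  static boolean newRule(boolean supportsSqlPushDown, List<Target> activeSourceTargets) {
    return !supportsSqlPushDown
        || activeSourceTargets.stream().anyMatch(target -> !target.hasTransforms());
  }

  public static void main(String[] args) {
    // The regression scenario: a node target WITH transformations sharing its
    // source with a custom query target (which can never define any).
    List<Target> targets =
        List.of(
            new Target("users", true), // node target with an aggregation
            new Target("user_name_starts_with", false)); // custom query target
    System.out.println(oldRule(true, targets)); // false: rows never fetched, NPE downstream
    System.out.println(newRule(true, targets)); // true: rows fetched for the custom query
  }
}

Because targetHasTransforms always returns false for QUERY targets, the new predicate necessarily holds whenever a custom query target is present, which restores the non-null guarantee the custom-query branch relies on in the diff below.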
fbiville authored Jan 7, 2025
1 parent 23be7bc commit 9d39b07
Showing 4 changed files with 172 additions and 43 deletions.
File 1 of 4: the template's pipeline class
@@ -81,7 +81,6 @@
 import org.neo4j.importer.v1.targets.RelationshipTarget;
 import org.neo4j.importer.v1.targets.Target;
 import org.neo4j.importer.v1.targets.TargetType;
-import org.neo4j.importer.v1.targets.Targets;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -275,6 +274,10 @@ public void run() {
                 Entry::getKey, mapping(Entry::getValue, Collectors.<PCollection<?>>toList())));
     var sourceRows = new ArrayList<PCollection<?>>(importSpecification.getSources().size());
     var targetRows = new HashMap<TargetType, List<PCollection<?>>>(targetCount());
+    var allActiveTargets =
+        importSpecification.getTargets().getAll().stream()
+            .filter(Target::isActive)
+            .collect(toList());
     var allActiveNodeTargets =
         importSpecification.getTargets().getNodes().stream()
             .filter(Target::isActive)
@@ -283,40 +286,42 @@ public void run() {
     ////////////////////////////
     // Process sources
     for (var source : importSpecification.getSources()) {
+      String sourceName = source.getName();
+      var activeSourceTargets =
+          allActiveTargets.stream()
+              .filter(target -> target.getSource().equals(sourceName))
+              .collect(toList());
+      if (activeSourceTargets.isEmpty()) {
+        return;
+      }

       // get provider implementation for source
       Provider provider = ProviderFactory.of(source, targetSequence);
       provider.configure(optionsParams);
       PCollection<Row> sourceMetadata =
           pipeline.apply(
-              String.format("Metadata for source %s", source.getName()), provider.queryMetadata());
+              String.format("Metadata for source %s", sourceName), provider.queryMetadata());
       sourceRows.add(sourceMetadata);
       Schema sourceBeamSchema = sourceMetadata.getSchema();
       processingQueue.addToQueue(
-          ArtifactType.source, false, source.getName(), defaultActionContext, sourceMetadata);
-      PCollection<Row> nullableSourceBeamRows = null;
+          ArtifactType.source, false, sourceName, defaultActionContext, sourceMetadata);

       ////////////////////////////
-      // Optimization: if single source query, reuse this PCollection rather than write it again
-      boolean targetsHaveTransforms = ModelUtils.targetsHaveTransforms(importSpecification, source);
-      if (!targetsHaveTransforms || !provider.supportsSqlPushDown()) {
+      // Optimization: if some of the current source's targets either
+      // - do not alter the source query (i.e. define no transformations)
+      // - or the source provider does not support SQL pushdown
+      // then the source PCollection can be defined here and reused across all the relevant targets
+      PCollection<Row> nullableSourceBeamRows = null;
+      if (!provider.supportsSqlPushDown()
+          || activeSourceTargets.stream()
+              .anyMatch(target -> !ModelUtils.targetHasTransforms(target))) {
         nullableSourceBeamRows =
             pipeline
-                .apply("Query " + source.getName(), provider.querySourceBeamRows(sourceBeamSchema))
+                .apply("Query " + sourceName, provider.querySourceBeamRows(sourceBeamSchema))
                 .setRowSchema(sourceBeamSchema);
       }

-      String sourceName = source.getName();
-
       ////////////////////////////
       // Optimization: if we're not mixing nodes and edges, then run in parallel
       // For relationship updates, max workers should be max 2. This parameter is job configurable.

       ////////////////////////////
       // No optimization possible so write nodes then edges.
       // Write node targets
-      List<NodeTarget> nodeTargets =
-          getActiveTargetsBySourceAndType(importSpecification, sourceName, TargetType.NODE);
+      List<NodeTarget> nodeTargets = getTargetsByType(activeSourceTargets, TargetType.NODE);
       for (NodeTarget target : nodeTargets) {
         TargetQuerySpec targetQuerySpec =
             new TargetQuerySpecBuilder()
@@ -327,7 +332,7 @@ public void run() {
         String nodeStepDescription =
             targetSequence.getSequenceNumber(target)
                 + ": "
-                + source.getName()
+                + sourceName
                 + "->"
                 + target.getName()
                 + " nodes";
@@ -371,7 +376,7 @@ public void run() {
       ////////////////////////////
       // Write relationship targets
       List<RelationshipTarget> relationshipTargets =
-          getActiveTargetsBySourceAndType(importSpecification, sourceName, TargetType.RELATIONSHIP);
+          getTargetsByType(activeSourceTargets, TargetType.RELATIONSHIP);
       for (var target : relationshipTargets) {
         var targetQuerySpec =
             new TargetQuerySpecBuilder()
@@ -383,14 +388,14 @@ public void run() {
                 .endNodeTarget(
                     findNodeTargetByName(allActiveNodeTargets, target.getEndNodeReference()))
                 .build();
-        PCollection<Row> preInsertBeamRows;
         String relationshipStepDescription =
             targetSequence.getSequenceNumber(target)
                 + ": "
-                + source.getName()
+                + sourceName
                 + "->"
                 + target.getName()
                 + " edges";
+        PCollection<Row> preInsertBeamRows;
         if (ModelUtils.targetHasTransforms(target)) {
           preInsertBeamRows =
               pipeline.apply(
@@ -439,12 +444,12 @@ public void run() {
       ////////////////////////////
       // Custom query targets
       List<CustomQueryTarget> customQueryTargets =
-          getActiveTargetsBySourceAndType(importSpecification, sourceName, TargetType.QUERY);
+          getTargetsByType(activeSourceTargets, TargetType.QUERY);
       for (Target target : customQueryTargets) {
         String customQueryStepDescription =
             targetSequence.getSequenceNumber(target)
                 + ": "
-                + source.getName()
+                + sourceName
                 + "->"
                 + target.getName()
                 + " (custom query)";
@@ -455,6 +460,8 @@ public void run() {
                 processingQueue.waitOnCollections(
                     target.getDependencies(), customQueryStepDescription));

+        // note: nullableSourceBeamRows is guaranteed to be non-null here since custom query targets
+        // cannot define source transformations
         PCollection<Row> blockingReturn =
             nullableSourceBeamRows
                 .apply(
@@ -581,15 +588,10 @@ private static NodeTarget findNodeTargetByName(List<NodeTarget> nodes, String re
   }

   @SuppressWarnings("unchecked")
-  private <T extends Target> List<T> getActiveTargetsBySourceAndType(
-      ImportSpecification importSpecification, String sourceName, TargetType targetType) {
-    Targets targets = importSpecification.getTargets();
-    return targets.getAll().stream()
-        .filter(
-            target ->
-                target.getTargetType() == targetType
-                    && target.isActive()
-                    && sourceName.equals(target.getSource()))
+  private <T extends Target> List<T> getTargetsByType(
+      List<Target> activeSourceTargets, TargetType targetType) {
+    return activeSourceTargets.stream()
+        .filter(target -> target.getTargetType() == targetType)
         .map(target -> (T) target)
         .collect(toList());
   }
File 2 of 4: ModelUtils.java
@@ -21,7 +21,6 @@
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Objects;
 import java.util.Set;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
@@ -38,8 +37,6 @@
 import net.sf.jsqlparser.statement.select.PlainSelect;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.commons.lang3.StringUtils;
-import org.neo4j.importer.v1.ImportSpecification;
-import org.neo4j.importer.v1.sources.Source;
 import org.neo4j.importer.v1.targets.Aggregation;
 import org.neo4j.importer.v1.targets.EntityTarget;
 import org.neo4j.importer.v1.targets.NodeTarget;
@@ -58,12 +55,6 @@ public class ModelUtils {
   private static final Pattern variablePattern = Pattern.compile("(\\$([a-zA-Z0-9_]+))");
   private static final Logger LOG = LoggerFactory.getLogger(ModelUtils.class);

-  public static boolean targetsHaveTransforms(ImportSpecification jobSpec, Source source) {
-    return jobSpec.getTargets().getAll().stream()
-        .filter(target -> target.isActive() && Objects.equals(target.getSource(), source.getName()))
-        .anyMatch(ModelUtils::targetHasTransforms);
-  }
-
   public static boolean targetHasTransforms(Target target) {
     if (target.getTargetType() == TargetType.QUERY) {
       return false;
File 3 of 4: SyntheticFieldsIT.java (new file)
@@ -0,0 +1,91 @@
/*
* Copyright (C) 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.google.cloud.teleport.v2.neo4j.templates;

import static com.google.cloud.teleport.v2.neo4j.templates.Connections.jsonBasicPayload;
import static com.google.cloud.teleport.v2.neo4j.templates.Resources.contentOf;
import static org.apache.beam.it.truthmatchers.PipelineAsserts.assertThatResult;

import com.google.cloud.teleport.metadata.TemplateIntegrationTest;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import org.apache.beam.it.common.PipelineLauncher.LaunchConfig;
import org.apache.beam.it.common.PipelineLauncher.LaunchInfo;
import org.apache.beam.it.common.PipelineOperator.Result;
import org.apache.beam.it.common.TestProperties;
import org.apache.beam.it.common.utils.ResourceManagerUtils;
import org.apache.beam.it.gcp.TemplateTestBase;
import org.apache.beam.it.neo4j.Neo4jResourceManager;
import org.apache.beam.it.neo4j.conditions.Neo4jQueryCheck;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@Category(TemplateIntegrationTest.class)
@TemplateIntegrationTest(GoogleCloudToNeo4j.class)
@RunWith(JUnit4.class)
public class SyntheticFieldsIT extends TemplateTestBase {

private Neo4jResourceManager neo4jClient;

@Before
public void setup() {
neo4jClient =
Neo4jResourceManager.builder(testName)
.setAdminPassword("letmein!")
.setHost(TestProperties.hostIp())
.build();
}

@After
public void tearDown() {
ResourceManagerUtils.cleanResources(neo4jClient);
}

@Test
// TODO: generate bigquery data set once import-spec supports value interpolation
public void importsStackoverflowUsers() throws IOException {
String spec = contentOf("/testing-specs/synthetic-fields/spec.yml");
gcsClient.createArtifact("spec.yml", spec);
gcsClient.createArtifact("neo4j-connection.json", jsonBasicPayload(neo4jClient));

LaunchConfig.Builder options =
LaunchConfig.builder(testName, specPath)
.addParameter("jobSpecUri", getGcsPath("spec.yml"))
.addParameter("neo4jConnectionUri", getGcsPath("neo4j-connection.json"));
LaunchInfo info = launchTemplate(options);

Result result =
pipelineOperator()
.waitForCondition(
createConfig(info),
Neo4jQueryCheck.builder(neo4jClient)
.setQuery("MATCH (u:User) RETURN count(u) AS count")
.setExpectedResult(List.of(Map.of("count", 10L)))
.build(),
Neo4jQueryCheck.builder(neo4jClient)
.setQuery(
"MATCH (l:Letter) WITH DISTINCT toUpper(l.char) AS char ORDER BY char ASC RETURN collect(char) AS chars")
.setExpectedResult(
List.of(Map.of("chars", List.of("A", "C", "G", "I", "J", "T", "W"))))
.build());
assertThatResult(result).meetsConditions();
}
}
File 4 of 4: testing-specs/synthetic-fields/spec.yml (new file)
@@ -0,0 +1,45 @@
version: '1'
sources:
  - type: bigquery
    name: so_users
    # once value interpolation is supported by import-spec, this public data set query
    # will be replaced by a query against a generated test bigquery data set
    query: |-
      SELECT id, display_name
      FROM
        `bigquery-public-data.stackoverflow.users`
      ORDER BY id ASC
      LIMIT 10
targets:
  nodes:
    - name: users
      source: so_users
      write_mode: merge
      labels: [User]
      source_transformations:
        aggregations:
          - expression: max(id)
            field_name: max_id
      properties:
        - source_field: id
          target_property: id
        - source_field: display_name
          target_property: name
        - source_field: max_id
          target_property: max_id
      schema:
        key_constraints:
          - name: key_user_id
            label: User
            properties: [id]
  queries:
    # here we just need a custom query from the same source as another node/rel target that defines transformations
    - name: user_name_starts_with
      depends_on:
        - users
      source: so_users
      query: |-
        UNWIND $rows AS row
        MATCH (user:User {id: row.id})
        MERGE (letter:Letter {char: left(user.name, 1)})
        CREATE (user)-[:NAME_STARTS_WITH]->(letter)
