Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition when loading the kryon subgraph parallelly just after VajramKryonGraph creation #330

Merged
merged 5 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
.DS_Store
**/.gradle/
**/build/
**/bin/

## From https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
import com.flipkart.krystal.krystex.kryon.KryonLogicId;
import com.flipkart.krystal.krystex.resolution.MultiResolver;
import com.flipkart.krystal.krystex.resolution.ResolverLogic;
import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.concurrent.ConcurrentHashMap;

public final class LogicDefinitionRegistry {
private final Map<KryonLogicId, OutputLogicDefinition<?>> outputLogicDefinitions =
new HashMap<>();
new ConcurrentHashMap<>();
private final Map<KryonLogicId, LogicDefinition<ResolverLogic>> resolverLogicDefinitions =
new HashMap<>();
new ConcurrentHashMap<>();
private final Map<KryonLogicId, LogicDefinition<MultiResolver>> multiResolverDefinitions =
new HashMap<>();
new ConcurrentHashMap<>();

@SuppressWarnings("unchecked")
public <T> OutputLogicDefinition<T> getOutputLogic(KryonLogicId kryonLogicId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
import com.flipkart.krystal.krystex.resolution.Resolver;
import com.flipkart.krystal.tags.ElementTags;
import com.google.common.collect.ImmutableMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

public final class KryonDefinitionRegistry {

private final LogicDefinitionRegistry logicDefinitionRegistry;
private final Map<KryonId, KryonDefinition> kryonDefinitions = new LinkedHashMap<>();
private final Map<KryonId, KryonDefinition> kryonDefinitions = new ConcurrentHashMap<>();
private final DependantChainStart dependantChainStart = new DependantChainStart();

public KryonDefinitionRegistry(LogicDefinitionRegistry logicDefinitionRegistry) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,16 @@ public final class VajramKryonGraph implements VajramExecutableGraph<KrystexVajr

private final LogicDefRegistryDecorator logicRegistryDecorator;

private final Map<VajramID, VajramDefinition> vajramDefinitions = new LinkedHashMap<>();
private final ConcurrentHashMap<Class<? extends Vajram<?>>, VajramDefinition> vajramDataByClass =
private final Map<VajramID, VajramDefinition> vajramDefinitions = new ConcurrentHashMap<>();
private final Map<Class<? extends Vajram<?>>, VajramDefinition> vajramDataByClass =
new ConcurrentHashMap<>();

/** These are those call graphs of a vajram where no other vajram depends on this. */
private final Map<VajramID, KryonId> vajramExecutables = new LinkedHashMap<>();
/**
* Maps every vajramId to its corresponding kryonId all of whose dependencies have also been
* loaded recursively. The mapped kryon id represents the complete executable sub-graph of the
* vajram.
*/
private final Map<VajramID, KryonId> vajramExecutables = new ConcurrentHashMap<>();

/** LogicDecorator Id -> LogicDecoratorConfig */
private final ImmutableMap<String, OutputLogicDecoratorConfig> sessionScopedDecoratorConfigs;
Expand Down Expand Up @@ -171,7 +175,7 @@ public void registerInputBatchers(VajramID vajramID, InputBatcherConfig... input
*/
public DependantChain computeDependantChain(
String firstVajramId, Dependency firstDependencyId, Dependency... subsequentDependencyIds) {
KryonId firstKryonId = _getVajramExecutionGraph(vajramID(firstVajramId));
KryonId firstKryonId = getKryonId(vajramID(firstVajramId));
KryonDefinition currentKryon = kryonDefinitionRegistry.get(firstKryonId);
DependantChain currentDepChain =
kryonDefinitionRegistry.getDependantChainsStart().extend(firstKryonId, firstDependencyId);
Expand Down Expand Up @@ -221,57 +225,65 @@ private void registerVajram(Vajram<Object> vajram) {
* @return {@link KryonId} of the {@link KryonDefinition} corresponding to this given vajramId
*/
KryonId getKryonId(VajramID vajramId) {
return _getVajramExecutionGraph(vajramId);
}

private KryonId _getVajramExecutionGraph(VajramID vajramId) {
KryonId kryonId = vajramExecutables.get(vajramId);
if (kryonId != null) {
return kryonId;
} else {
return loadKryonSubgraph(vajramId, new LinkedHashMap<>());
}
}

private KryonId loadKryonSubgraph(VajramID vajramId, Map<VajramID, KryonId> loadingInProgress) {
synchronized (vajramExecutables) {
KryonId kryonId;
if ((kryonId = vajramExecutables.get(vajramId)) != null) {
// This means the subgraph is already loaded.
return kryonId;
} else if ((kryonId = loadingInProgress.get(vajramId)) != null) {
// This means the subgraph is still being loaded, but there is a cyclic dependency. Just
// return the kryon to prevent infinite recursion.
return kryonId;
}
kryonId = new KryonId(vajramId.vajramId());
// add to loadingInProgress so that this can be used to prevent infinite recursion in the
// cases where a vajram depends on itself in a cyclic dependency.
loadingInProgress.put(vajramId, kryonId);
VajramDefinition vajramDefinition =
getVajramDefinition(vajramId)
.orElseThrow(
() ->
new NoSuchElementException(
"Could not find vajram with id: %s".formatted(vajramId)));
InputResolverCreationResult inputResolverCreationResult =
createKryonLogicsForInputResolvers(vajramDefinition);
ImmutableMap<Dependency, KryonId> depIdToProviderKryon =
createKryonDefinitionsForDependencies(vajramDefinition, loadingInProgress);
OutputLogicDefinition<?> outputLogicDefinition =
createKryonOutputLogic(kryonId, vajramDefinition);
ImmutableSet<? extends Facet> inputIds = vajramDefinition.facetSpecs();
LogicDefinition<CreateNewRequest> createNewRequest =
new LogicDefinition<>(
new KryonLogicId(kryonId, "%s:newRequest"),
ImmutableSet.of(),
emptyTags(),
vajramDefinition.vajram()::newRequestBuilder);
KryonDefinition kryonDefinition =
kryonDefinitionRegistry.newKryonDefinition(
kryonId.value(),
inputIds,
outputLogicDefinition.kryonLogicId(),
depIdToProviderKryon,
inputResolverCreationResult.resolversByDefinition(),
createNewRequest,
new LogicDefinition<>(
new KryonLogicId(kryonId, "%s:facetsFromRequest"),
ImmutableSet.of(),
emptyTags(),
r -> vajramDefinition.vajram().facetsFromRequest(r)),
vajramDefinition.vajramTags());
vajramExecutables.put(vajramId, kryonId);
return kryonDefinition.kryonId();
}
vajramExecutables.put(vajramId, kryonId);

VajramDefinition vajramDefinition =
getVajramDefinition(vajramId)
.orElseThrow(
() ->
new NoSuchElementException(
"Could not find vajram with id: %s".formatted(vajramId)));

InputResolverCreationResult inputResolverCreationResult =
createKryonLogicsForInputResolvers(vajramDefinition);

ImmutableMap<Dependency, KryonId> depIdToProviderKryon =
createKryonDefinitionsForDependencies(vajramDefinition);

OutputLogicDefinition<?> outputLogicDefinition =
createKryonOutputLogic(kryonId, vajramDefinition);

ImmutableSet<? extends Facet> inputIds = vajramDefinition.facetSpecs();

LogicDefinition<CreateNewRequest> createNewRequest =
new LogicDefinition<>(
new KryonLogicId(kryonId, "%s:newRequest"),
ImmutableSet.of(),
emptyTags(),
vajramDefinition.vajram()::newRequestBuilder);
KryonDefinition kryonDefinition =
kryonDefinitionRegistry.newKryonDefinition(
kryonId.value(),
inputIds,
outputLogicDefinition.kryonLogicId(),
depIdToProviderKryon,
inputResolverCreationResult.resolversByDefinition(),
createNewRequest,
new LogicDefinition<>(
new KryonLogicId(kryonId, "%s:facetsFromRequest"),
ImmutableSet.of(),
emptyTags(),
r -> vajramDefinition.vajram().facetsFromRequest(r)),
vajramDefinition.vajramTags());
return kryonDefinition.kryonId();
}

private InputResolverCreationResult createKryonLogicsForInputResolvers(
Expand Down Expand Up @@ -448,7 +460,7 @@ private OutputLogicDefinition<?> createKryonOutputLogic(
}

private ImmutableMap<Dependency, KryonId> createKryonDefinitionsForDependencies(
VajramDefinition vajramDefinition) {
VajramDefinition vajramDefinition, Map<VajramID, KryonId> loadingInProgress) {
List<DependencySpec> dependencies = new ArrayList<>();
for (Facet facet : vajramDefinition.facetSpecs()) {
if (facet instanceof DependencySpec definition) {
Expand All @@ -464,7 +476,8 @@ private ImmutableMap<Dependency, KryonId> createKryonDefinitionsForDependencies(
throw new VajramDefinitionException(
"Unable to find vajram for vajramId %s".formatted(accessSpec));
}
depIdToProviderKryon.put(dependency, _getVajramExecutionGraph(dependencyVajram.vajramId()));
depIdToProviderKryon.put(
dependency, loadKryonSubgraph(dependencyVajram.vajramId(), loadingInProgress));
}
return ImmutableMap.copyOf(depIdToProviderKryon);
}
Expand Down
Loading