Skip to content

Commit

Permalink
improve creation order to preserve primary key order respective to de…
Browse files Browse the repository at this point in the history
…claration order in yaml
  • Loading branch information
briancappello-gl committed Jan 10, 2025
1 parent 96bbac4 commit d4f6fa6
Showing 1 changed file with 33 additions and 13 deletions.
46 changes: 33 additions & 13 deletions py_yaml_fixtures/fixtures_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,19 @@ def create_all(
# build up a directed acyclic graph to determine the model instantiation order
dag = nx.DiGraph()
for model_class_name, dependencies in self.relationships.items():
dag.add_node(model_class_name)
if model_class_name not in dependencies:
for dep in dependencies:
dag.add_edge(model_class_name, dep)
continue

for dep in dependencies:
associated_col_name = dependencies[dep]
for id_key, instance_data in self.model_fixtures[model_class_name].items():
identifier = Identifier(model_class_name, id_key)
dag.add_node(identifier)
dag.add_edge(model_class_name, identifier)

dag.add_node(identifier)
associated_identifiers = instance_data.get(associated_col_name)
if associated_identifiers is None:
continue
Expand All @@ -98,30 +105,42 @@ def create_all(
dag.add_edge(identifier, associated_identifier)

try:
creation_order = reversed(list(nx.topological_sort(dag)))
creation_order = list(reversed(list(nx.topological_sort(dag))))
except nx.NetworkXUnfeasible:
raise Exception('Circular dependency detected between models: ' +
', '.join('{a} -> {b}'.format(a=a, b=b)
for a, b in nx.find_cycle(dag)))

# create or update the models in the determined order
rv = []
rv = {}
for identifier in creation_order:
data = self.factory.maybe_convert_values(
identifier,
data=self.model_fixtures[identifier.class_name][identifier.key],
)
self._data_cache[identifier.class_name][identifier.key] = data
if isinstance(identifier, str):
model_class_name = identifier
keys = list(self.model_fixtures[model_class_name].keys())
else:
model_class_name = identifier.class_name
keys = [identifier.key]

for key in keys:
identifier = Identifier(model_class_name, key)
if identifier in rv:
continue

data = self.factory.maybe_convert_values(
identifier,
data=self.model_fixtures[identifier.class_name][identifier.key],
)
self._data_cache[identifier.class_name][identifier.key] = data

model_instance, created = self.factory.create_or_update(identifier, data)
if progress_callback:
progress_callback(identifier, model_instance, created)
rv.append(model_instance)
model_instance, created = self.factory.create_or_update(identifier, data)
if progress_callback:
progress_callback(identifier, model_instance, created)
rv[identifier] = model_instance

# FIXME if there are any model names in the seed files but not in creation_order

self.factory.commit()
return rv
return list(rv.values())

def convert_identifiers(self, identifiers: Union[Identifier, List[Identifier]]):
"""
Expand Down Expand Up @@ -272,6 +291,7 @@ def _ensure_env(self, env: Union[jinja2.Environment, None]):
env.globals['faker'] = faker

env.globals.setdefault('hash_password', hash_password)
env.filters.setdefault('isoformat', lambda dt: dt.isoformat())
if hasattr(jinja2, 'pass_context'):
env.globals.setdefault('random_model', jinja2.pass_context(random_model))
env.globals.setdefault('random_models', jinja2.pass_context(random_models))
Expand Down

0 comments on commit d4f6fa6

Please sign in to comment.