Skip to content

Commit

Permalink
Merge pull request #2538 from SFDO-Tooling/feature/snowfakery-1.9
Browse files Browse the repository at this point in the history
Support for Snowfakery 1.9 API
  • Loading branch information
prescod authored Apr 14, 2021
2 parents ff5e762 + eafbbe8 commit e95a332
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 45 deletions.
55 changes: 15 additions & 40 deletions cumulusci/tasks/bulkdata/generate_from_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,12 @@
import shutil
from contextlib import contextmanager

import yaml


from cumulusci.core.utils import process_list_of_pairs_dict_arg, process_list_arg

from cumulusci.core.exceptions import TaskOptionsError
from cumulusci.tasks.bulkdata.base_generate_data_task import BaseGenerateDataTask
from cumulusci.tasks.bulkdata.mapping_parser import parse_from_yaml
from snowfakery.output_streams import SqlDbOutputStream
from snowfakery.data_generator import generate, StoppingCriteria
from snowfakery.generate_mapping_from_recipe import mapping_from_recipe_templates
from snowfakery.cli import gather_declarations
from snowfakery.salesforce import create_cci_record_type_tables
from snowfakery import generate_data


class GenerateDataFromYaml(BaseGenerateDataTask):
Expand Down Expand Up @@ -91,22 +84,17 @@ def _init_options(self, kwargs):
"Cannot specify num_records without num_records_tablename."
)

self.stopping_criteria = StoppingCriteria(
num_records_tablename, num_records
)
self.stopping_criteria = (num_records, num_records_tablename)
self.working_directory = self.options.get("working_directory")
loading_rules = process_list_arg(self.options.get("loading_rules")) or []
self.loading_rules = [Path(path) for path in loading_rules if path]

def _generate_data(self, db_url, mapping_file_path, num_records, current_batch_num):
"""Generate all of the data"""
if mapping_file_path:
self.mapping = parse_from_yaml(mapping_file_path)
else:
self.mapping = {}
if num_records is not None: # num_records is None means execute Snowfakery once
self.logger.info(f"Generating batch {current_batch_num} with {num_records}")
self.generate_data(db_url, num_records, current_batch_num)
self.logger.info("Generated batch")

def default_continuation_file_path(self):
    """Return the default location of the Snowfakery continuation file.

    The file lives directly inside ``self.working_directory`` under the
    fixed name ``continuation.yml``.
    """
    return Path(self.working_directory, "continuation.yml")
Expand Down Expand Up @@ -155,25 +143,23 @@ def open_new_continuation_file(self):
else:
yield None

def generate_data(self, db_url, num_records, current_batch_num):
output_stream = SqlDbOutputStream.from_url(db_url, self.mapping)
def generate_data(self, dburl, num_records, current_batch_num):
old_continuation_file = self.get_old_continuation_file()
if old_continuation_file:
# reopen to ensure file pointer is at starting point
old_continuation_file = open(old_continuation_file, "r")
with self.open_new_continuation_file() as new_continuation_file:
try:
with open(self.yaml_file) as open_yaml_file:
summary = generate(
open_yaml_file=open_yaml_file,
user_options=self.vars,
output_stream=output_stream,
stopping_criteria=self.stopping_criteria,
continuation_file=old_continuation_file,
generate_continuation_file=new_continuation_file,
)
finally:
output_stream.close()
generate_data(
yaml_file=self.yaml_file,
user_options=self.vars,
target_number=self.stopping_criteria,
continuation_file=old_continuation_file,
generate_continuation_file=new_continuation_file,
generate_cci_mapping_file=self.generate_mapping_file,
dburl=dburl,
load_declarations=self.loading_rules,
should_create_cci_record_type_tables=True,
)

if (
new_continuation_file
Expand All @@ -183,14 +169,3 @@ def generate_data(self, db_url, num_records, current_batch_num):
shutil.copyfile(
new_continuation_file.name, self.default_continuation_file_path()
)

if self.generate_mapping_file:
declarations = gather_declarations(self.yaml_file, self.loading_rules)
with open(self.generate_mapping_file, "w+") as f:
yaml.safe_dump(
mapping_from_recipe_templates(summary, declarations),
f,
sort_keys=False,
)

create_cci_record_type_tables(db_url)
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,9 @@ def test_simple_generate_and_load(self, _dataload):
task()
assert len(_dataload.mock_calls) == 1

@mock.patch("cumulusci.tasks.bulkdata.generate_from_yaml.generate")
def test_exception_handled_cleanly(self, generate):
generate.side_effect = AssertionError("Foo")
@mock.patch("cumulusci.tasks.bulkdata.generate_from_yaml.generate_data")
def test_exception_handled_cleanly(self, generate_data):
generate_data.side_effect = AssertionError("Foo")
with pytest.raises(AssertionError) as e:
task = _make_task(
GenerateAndLoadDataFromYaml,
Expand All @@ -207,7 +207,7 @@ def test_exception_handled_cleanly(self, generate):
)
task()
assert "Foo" in str(e.value)
assert len(generate.mock_calls) == 1
assert len(generate_data.mock_calls) == 1

@mock.patch(
"cumulusci.tasks.bulkdata.generate_and_load_data_from_yaml.GenerateAndLoadDataFromYaml._dataload"
Expand Down
2 changes: 1 addition & 1 deletion requirements/prod.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ six==1.15.0
# fs
# python-dateutil
# salesforce-bulk
snowfakery==1.8.1
snowfakery==1.9
# via -r requirements/prod.in
sqlalchemy==1.3.24
# via
Expand Down

0 comments on commit e95a332

Please sign in to comment.