Skip to content

Commit

Permalink
Update input to be marginally less brittle
Browse files Browse the repository at this point in the history
  • Loading branch information
brynpickering committed May 13, 2021
1 parent 160b99b commit fe2773e
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 25 deletions.
4 changes: 2 additions & 2 deletions rules/rule_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ def collect_shape_dirs(config, rules):
inputs = {}
for shape in shapes:
if shape in ["nuts", "gadm", "lau"]:
inputs[shape] = getattr(rules, f"administrative_borders_{shape}").output[0]
inputs[f"SHAPEINPUT_{shape}"] = getattr(rules, f"administrative_borders_{shape}").output[0]
else:
inputs[shape] = config["data-sources"][shape]
inputs[f"SHAPEINPUT_{shape}"] = config["data-sources"][shape]
return inputs
5 changes: 4 additions & 1 deletion scripts/administrative_borders.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def normalise_admin_borders(crs, scope_config, path_to_output, **shape_dirs):


if __name__ == "__main__":
shape_dirs = {k: v for k, v in snakemake.input.items() if k != "src"}
shape_dirs = {
source.replace("SHAPEINPUTS_", ""): source_dir
for source, source_dir in snakemake.input.items() if source.startswith("SHAPEINPUTS")
}
normalise_admin_borders(
crs=snakemake.params.crs,
scope_config=snakemake.params.scope,
Expand Down
59 changes: 37 additions & 22 deletions tests/rules/test_rule_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@


class TestCustomShapes:
PREFIX = "SHAPEINPUT_"
@pytest.fixture
def rules(self):
# This emulates a snakemake rule object, for lack of access to one in the tests
class DotDict(dict):
"""dot.notation access to dictionary attributes"""
def __getattr__(*args):
Expand All @@ -19,47 +21,60 @@ def __getattr__(*args):
"administrative_borders_lau": {"output": ["you_got_lau"]}
})

@pytest.mark.parametrize("number", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 100, 1000, 5555))
def test_strip(self, rules, number):
config = {"layers": {"national": {"foo": "nuts{}".format(number)}}}
collected = collect_shape_dirs(config, rules)
assert collected == {"nuts": "you_got_nuts"}

@pytest.mark.parametrize("number", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 100, 1000, 5555))
def test_no_strip(self, rules, number):
config = {"layers": {"national": {"foo": "{0}bar{0}".format(number)}}, "data-sources": {f"{number}bar": "foobar"}}
collected = collect_shape_dirs(config, rules)
assert collected == {f"{number}bar": "foobar"}

def test_standard_shapes(self, rules):
@pytest.fixture
def standard_shapes(self, rules):
config = {
"layers": {
"national": {"foo": "nuts2", "bar": "nuts3", "baz": "gadm0"},
"custom": {"foo": "nuts2", "bar": "lau2", "baz": "gadm0"}
},
"data-sources": {"my-custom-shape": "data/dir.geojson"}
}
collected = collect_shape_dirs(config, rules)
assert set(collected.keys()) == set(["lau", "nuts", "gadm"])
for key in ["lau", "nuts", "gadm"]:
assert collected[key] == f"you_got_{key}"
return collect_shape_dirs(config, rules)

def test_new_shapes(self, rules):
@pytest.fixture
def new_shapes(self, rules):
config = {
"layers": {
"national": {"foo": "my-custom-shape1", "bar": "nuts3", "baz": "my-custom-shape2"},
"custom": {"foo": "your-custom-shape2", "bar": "my-custom-shape12", "baz": "gadm0"}
},
"data-sources": {"my-custom-shape": "mydata/dir.geojson", "your-custom-shape": "yourdata/dir.geojson"}
}
return collect_shape_dirs(config, rules)

@pytest.mark.parametrize("number", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 100, 1000, 5555))
def test_strip_trailing_number(self, rules, number):
config = {"layers": {"national": {"foo": "nuts{}".format(number)}}}
collected = collect_shape_dirs(config, rules)
assert collected == {f"{self.PREFIX}nuts": "you_got_nuts"}

@pytest.mark.parametrize("number", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 100, 1000, 5555))
def test_no_strip(self, rules, number):
config = {"layers": {"national": {"foo": "{0}bar{0}".format(number)}}, "data-sources": {f"{number}bar": "foobar"}}
collected = collect_shape_dirs(config, rules)
assert set(collected.keys()) == set(["my-custom-shape", "your-custom-shape", "nuts", "gadm"])
assert collected == {f"{self.PREFIX}{number}bar": "foobar"}


def test_collected_standard_shape_sources(self, standard_shapes):
assert set(standard_shapes.keys()) == set(f"{self.PREFIX}{i}" for i in ["lau", "nuts", "gadm"])

def test_collected_standard_shapes(self, standard_shapes):
for key in ["lau", "nuts", "gadm"]:
assert standard_shapes[f"{self.PREFIX}{key}"] == f"you_got_{key}"

def test_collected_new_shape_sources(self, new_shapes):
assert set(new_shapes.keys()) == set(f"{self.PREFIX}{i}" for i in ["my-custom-shape", "your-custom-shape", "nuts", "gadm"])

def test_collected_new_shapes_when_defined(self, new_shapes):
for key in ["my", "your"]:
assert collected[f"{key}-custom-shape"] == f"{key}data/dir.geojson"
assert new_shapes[f"{self.PREFIX}{key}-custom-shape"] == f"{key}data/dir.geojson"

def test_collected_standard_shapes_when_new_shapes_defined(self, new_shapes):
for key in ["nuts", "gadm"]:
assert collected[key] == f"you_got_{key}"
assert new_shapes[f"{self.PREFIX}{key}"] == f"you_got_{key}"

def test_will_fail(self, rules):
def test_no_valid_data_source_for_shape(self, rules):
config = {
"layers": {
"national": {"foo": "my-custom-shape", "bar": "nuts3"},
Expand Down

0 comments on commit fe2773e

Please sign in to comment.