-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #342 from Living-with-machines/339-postproc
Add context-based post processing for linear features
- Loading branch information
Showing
5 changed files
with
536 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,94 @@ | ||
Post-process | ||
============= | ||
|
||
TBC | ||
MapReader post-processing's sub-package currently contains one method for post-processing the predictions from your model based on the idea that features such as railways, roads, coastlines, etc. are continuous and so patches with these labels should be found near to other patches also with these labels. | ||
For example, if a patch is predicted to be a railspace, but is surrounded by patches predicted to be non-railspace, then it is likely that the railspace patch is a false positive. | ||
|
||
To implement this, for a given patch, the code checks whether any of the 8 surrounding patches have the same label (e.g. 'railspace') and, if not, assumes the current patch's predicted label to be a false positive. | ||
The user can then choose how to relabel the patch (e.g. 'railspace' -> 'no'). | ||
|
||
To run the post-processing code, you will need to have saved the predictions from your model in the format expected for the post-processing code. | ||
See the :doc:`/User-guide/Classify/Classify` docs for more on this. | ||
|
||
If you have your predictions saved in a csv file, you will first need to load them into a pandas DataFrame: | ||
|
||
.. code-block:: python | ||
import pandas as pd | ||
preds = pd.read_csv("path/to/predictions.csv", index_col=0) | ||
You can then run the post-processing code as follows: | ||
|
||
.. code-block:: python | ||
from mapreader.process.post_process import PatchDataFrame | ||
labels_map = { | ||
0: "no", | ||
1: "railspace", | ||
2: "building", | ||
3: "railspace&building" | ||
} | ||
patches = PatchDataFrame(preds, labels_map=labels_map) | ||
MapReader's post-processing will only work for features that are expected be continuous (e.g. railway, road, coastline, etc.) or clustered (e.g. a large body of water). | ||
You will need to tell MapReader which labels to select and then get the context for each of the relevant patches in order to work out if it is isolated or part of a line/cluster. | ||
|
||
For example, if you want to post-process patches which are predicted to be 'railspace' or 'railspace&building', you would do the following: | ||
|
||
.. code-block:: python | ||
labels=["railspace", "railspace&building"] | ||
patches.get_context(labels=labels) | ||
.. note:: In the above example, we needed to use both 'railspace' and 'railspace&building' as our labels, since the continuous feature we are trying to post-process is railway lines (included in both these labels). | ||
|
||
You will also need to tell MapReader how to update the label of each patch that is isolated and therefore likely to be a false positive. | ||
This is done using the ``remap`` argument, which takes a dictionary of the form ``{old_label: new_label}``. | ||
|
||
For example, if you want to remap all isolated 'railspace' patches to be labelled as 'no', and all isolated 'railspace&building' patches to be labelled as 'building', you would do the following: | ||
|
||
.. code-block:: python | ||
remap={"railspace": "no", "railspace&building": "building"} | ||
patches.update_preds(remap=remap) | ||
By default, only patches with model confidence of below 0.7 will be relabelled. | ||
You can adjust this by passing the ``conf`` argument. | ||
|
||
e.g. to relabel all isolated patches with confidence below 0.9, you would do the following: | ||
|
||
.. code-block:: python | ||
remap={"railspace": "no", "railspace&building": "building"} | ||
patches.update_preds(remap=remap, conf=0.9) | ||
Instead of relabelling your chosen patches to an existing label, you can also choose to relabel them to a new label. | ||
For example, to mark them as 'false_positive', you would do the following: | ||
|
||
.. code-block:: python | ||
remap={"railspace": "false_positive", "railspace&building": "false_positive"} | ||
patches.update_preds(remap=remap) | ||
By default, after running `update_preds`, a new column will be added to your ``patches`` DataFrame called "new_predicted_label". | ||
This will contain the updated predictions (or NaN if the patch was not relabelled). | ||
|
||
Alternatively, to save the updated predictions inplace you can pass the ``inplace`` argument: | ||
|
||
.. code-block:: python | ||
remap={"railspace": "no", "railspace&building": "building"} | ||
patches.update_preds(remap=remap, inplace=True) | ||
Finally, to save your outputs to a csv file, you can do the following: | ||
|
||
.. code-block:: python | ||
patches.to_csv("path/to/save/updated_predictions.csv") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
#!/usr/bin/env python | ||
from __future__ import annotations | ||
|
||
from ast import literal_eval | ||
from itertools import product | ||
|
||
import pandas as pd | ||
from tqdm import tqdm | ||
|
||
|
||
class PatchDataFrame(pd.DataFrame): | ||
"""A class for storing patch dataframes. | ||
Parameters | ||
---------- | ||
patch_df : pd.DataFrame | ||
the DataFrame containing patches and predictions | ||
labels_map : dict | ||
the dictionary mapping label indices to their labels. | ||
e.g. `{0: "no", 1: "railspace"}`. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
patch_df: pd.DataFrame, | ||
labels_map: dict, | ||
): | ||
required_columns = [ | ||
"parent_id", | ||
"pixel_bounds", | ||
"pred", | ||
"predicted_label", | ||
"conf", | ||
] | ||
if not all([col in patch_df.columns for col in required_columns]): | ||
raise ValueError( | ||
f"[ERROR] Your dataframe must contain the following columns: {required_columns}." | ||
) | ||
|
||
# ensure lists/tuples are evaluated as such | ||
for col in patch_df.columns: | ||
try: | ||
patch_df[col] = patch_df[col].apply(literal_eval) | ||
except (ValueError, TypeError, SyntaxError): | ||
pass | ||
|
||
if patch_df.index.name != "image_id" and "image_id" in patch_df.columns: | ||
patch_df.set_index("image_id", drop=True, inplace=True) | ||
|
||
if all( | ||
[col in patch_df.columns for col in ["min_x", "min_y", "max_x", "max_y"]] | ||
): | ||
print( | ||
"[INFO] Using existing pixel bounds columns (min_x, min_y, max_x, max_y)." | ||
) | ||
else: | ||
patch_df[["min_x", "min_y", "max_x", "max_y"]] = [*patch_df["pixel_bounds"]] | ||
|
||
super().__init__(patch_df) | ||
|
||
self.labels_map = labels_map | ||
self._label_patches = None | ||
self.context = {} | ||
|
||
def get_context( | ||
self, | ||
labels: str | list, | ||
): | ||
"""Get the context of the patches with the specified labels. | ||
Parameters | ||
---------- | ||
labels : str | list | ||
The label(s) to get context for. | ||
""" | ||
if isinstance(labels, str): | ||
labels = [labels] | ||
self._label_patches = self[self["predicted_label"].isin(labels)] | ||
|
||
for id in tqdm(self._label_patches.index): | ||
if id not in self.context: | ||
context_list = self._get_context_id(id) | ||
# only add context if all surrounding patches are found | ||
if len(context_list) == 9: | ||
self.context[id] = context_list | ||
|
||
def _get_context_id( | ||
self, | ||
id, | ||
): | ||
"""Get the context of the patch with the specified index.""" | ||
parent_id = self.loc[id, "parent_id"] | ||
min_x = self.loc[id, "min_x"] | ||
min_y = self.loc[id, "min_y"] | ||
max_x = self.loc[id, "max_x"] | ||
max_y = self.loc[id, "max_y"] | ||
|
||
context_grid = [ | ||
*product( | ||
[(self["min_x"], min_x), (min_x, max_x), (max_x, self["max_x"])], | ||
[(self["min_y"], min_y), (min_y, max_y), (max_y, self["max_y"])], | ||
) | ||
] | ||
# reshape to min_x, min_y, max_x, max_y | ||
context_grid = [(x[0][0], x[1][0], x[0][1], x[1][1]) for x in context_grid] | ||
|
||
context_list = [ | ||
self[ | ||
(self["min_x"] == context_loc[0]) | ||
& (self["min_y"] == context_loc[1]) | ||
& (self["max_x"] == context_loc[2]) | ||
& (self["max_y"] == context_loc[3]) | ||
& (self["parent_id"] == parent_id) | ||
] | ||
for context_loc in context_grid | ||
] | ||
context_list = [x.index[0] for x in context_list if len(x)] | ||
return context_list | ||
|
||
def update_preds(self, remap: dict, conf: float = 0.7, inplace: bool = False): | ||
"""Update the predictions of the chosen patches based on their context. | ||
Parameters | ||
---------- | ||
remap : dict | ||
A dictionary mapping the old labels to the new labels. | ||
conf : float, optional | ||
Patches with confidence scores below this value will be relabelled, by default 0.7. | ||
inplace : bool, optional | ||
Whether to relabel inplace or create new dataframe columns, by default False | ||
""" | ||
if self._label_patches is None: | ||
raise ValueError("[ERROR] You must run `get_context` first.") | ||
if len(self.context) == 0: | ||
raise ValueError( | ||
"[ERROR] No patches to update. Try changing which labels you are updating." | ||
) | ||
|
||
labels = self._label_patches["predicted_label"].unique() | ||
if any([label not in remap.keys() for label in labels]): | ||
raise ValueError( | ||
f"[ERROR] You must specify a remap for each label in {labels}." | ||
) | ||
|
||
# add new label to labels_map if not already present (assume label index is next in sequence) | ||
for new_label in remap.values(): | ||
if new_label not in self.labels_map.values(): | ||
print( | ||
[ | ||
f"[INFO] Adding {new_label} to labels_map at index {len(self.labels_map)}." | ||
] | ||
) | ||
self.labels_map[len(self.labels_map)] = new_label | ||
|
||
for id in tqdm(self.context): | ||
self._update_preds_id( | ||
id, labels=labels, remap=remap, conf=conf, inplace=inplace | ||
) | ||
|
||
def _update_preds_id( | ||
self, id, labels: str | list, remap: dict, conf: float, inplace: bool = False | ||
): | ||
"""Update the predictions of the patch with the specified index.""" | ||
context_list = self.context[id] | ||
|
||
context_df = self[self.index.isin(context_list)] | ||
# drop central patch from context | ||
context_df.drop(index=id, inplace=True) | ||
|
||
# reverse the labels_map dict | ||
label_index_dict = {v: k for k, v in self.labels_map.items()} | ||
|
||
prefix = "" if inplace else "new_" | ||
if (not any(context_df["predicted_label"].isin(labels))) & ( | ||
self.loc[id, "conf"] < conf | ||
): | ||
self.loc[id, f"{prefix}predicted_label"] = remap[ | ||
self.loc[id, "predicted_label"] | ||
] | ||
self.loc[id, f"{prefix}pred"] = label_index_dict[ | ||
self.loc[id, f"{prefix}predicted_label"] | ||
] |
Oops, something went wrong.