Add context-based post processing for linear features #342
Merged
9 commits
02e2436 enable easier saving of predictions to csv (rwood-97)
60641bb add post processing script (rwood-97)
f6f5e89 add docstrings, allow user to specify conf (rwood-97)
08136a4 skip edge patches, allow new labels (rwood-97)
9b9003c force image_id index (rwood-97)
3c58460 add tests (rwood-97)
f668a73 Add post-processing docs (rwood-97)
1abce20 add suggestion (rwood-97)
777c857 Merge branch 'main' into 339-postproc (rwood-97)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,94 @@ | ||
Post-process | ||
============= | ||
|
||
TBC | ||
MapReader's post-processing sub-package currently contains one method for post-processing your model's predictions. It is based on the idea that features such as railways, roads and coastlines are continuous, so patches with these labels should be found near other patches with the same labels. | ||
For example, if a patch is predicted to be a railspace, but is surrounded by patches predicted to be non-railspace, then it is likely that the railspace patch is a false positive. | ||
|
||
To implement this, for a given patch, the code checks whether any of the 8 surrounding patches have the same label (e.g. 'railspace') and, if not, assumes the current patch's predicted label to be a false positive. | ||
The user can then choose how to relabel the patch (e.g. 'railspace' -> 'no'). | ||
|
||
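To run the post-processing code, you will need to have saved your model's predictions in the format expected by the post-processing code (a table with, at minimum, the ``parent_id``, ``pixel_bounds``, ``pred``, ``predicted_label`` and ``conf`` columns). | ||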
See the :doc:`/User-guide/Classify/Classify` docs for more on this. | ||
|
||
If you have your predictions saved in a csv file, you will first need to load them into a pandas DataFrame: | ||
|
||
.. code-block:: python | ||
|
||
import pandas as pd | ||
|
||
preds = pd.read_csv("path/to/predictions.csv", index_col=0) | ||
|
||
|
||
You can then run the post-processing code as follows: | ||
|
||
.. code-block:: python | ||
|
||
from mapreader.process.post_process import PatchDataFrame | ||
|
||
labels_map = { | ||
0: "no", | ||
1: "railspace", | ||
2: "building", | ||
3: "railspace&building" | ||
} | ||
|
||
patches = PatchDataFrame(preds, labels_map=labels_map) | ||
|
||
MapReader's post-processing will only work for features that are expected to be continuous (e.g. railway, road, coastline, etc.) or clustered (e.g. a large body of water). | ||
You will need to tell MapReader which labels to select and then get the context for each of the relevant patches in order to work out if it is isolated or part of a line/cluster. | ||
|
||
For example, if you want to post-process patches which are predicted to be 'railspace' or 'railspace&building', you would do the following: | ||
|
||
.. code-block:: python | ||
|
||
labels=["railspace", "railspace&building"] | ||
patches.get_context(labels=labels) | ||
|
||
|
||
.. note:: In the above example, we needed to use both 'railspace' and 'railspace&building' as our labels, since the continuous feature we are trying to post-process is railway lines (included in both these labels). | ||
|
||
You will also need to tell MapReader how to update the label of each patch that is isolated and therefore likely to be a false positive. | ||
This is done using the ``remap`` argument, which takes a dictionary of the form ``{old_label: new_label}``. | ||
|
||
For example, if you want to remap all isolated 'railspace' patches to be labelled as 'no', and all isolated 'railspace&building' patches to be labelled as 'building', you would do the following: | ||
|
||
.. code-block:: python | ||
|
||
remap={"railspace": "no", "railspace&building": "building"} | ||
patches.update_preds(remap=remap) | ||
|
||
By default, only isolated patches with a model confidence below 0.7 will be relabelled. | ||
You can adjust this threshold by passing the ``conf`` argument. | ||
|
||
For example, to relabel all isolated patches with a confidence below 0.9, you would do the following: | ||
|
||
.. code-block:: python | ||
|
||
remap={"railspace": "no", "railspace&building": "building"} | ||
patches.update_preds(remap=remap, conf=0.9) | ||
|
||
Instead of relabelling your chosen patches to an existing label, you can also choose to relabel them to a new label. | ||
For example, to mark them as 'false_positive', you would do the following: | ||
|
||
.. code-block:: python | ||
|
||
remap={"railspace": "false_positive", "railspace&building": "false_positive"} | ||
patches.update_preds(remap=remap) | ||
|
||
|
||
By default, after running ``update_preds``, two new columns will be added to your ``patches`` DataFrame: ``new_predicted_label`` and ``new_pred``. | ||
These will contain the updated predictions (or NaN if the patch was not relabelled). | ||
|
||
Alternatively, to update the predictions in place, you can pass the ``inplace`` argument: | ||
|
||
.. code-block:: python | ||
|
||
remap={"railspace": "no", "railspace&building": "building"} | ||
patches.update_preds(remap=remap, inplace=True) | ||
|
||
|
||
Finally, to save your outputs to a csv file, you can do the following: | ||
|
||
.. code-block:: python | ||
|
||
patches.to_csv("path/to/save/updated_predictions.csv") |
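A note on the CSV round trip: tuple-valued columns such as ``pixel_bounds`` are written to a csv file as text, which is presumably why ``PatchDataFrame`` applies ``literal_eval`` to each column on load. The sketch below illustrates this with the standard library only. The row values and the ``image_id`` string are made up for illustration, and only the column names come from this PR.

```python
import csv
import io
from ast import literal_eval

# Toy row using the column names the post-processing class requires.
# All values here are invented for illustration.
row = {
    "image_id": "patch_a",
    "parent_id": "map.png",
    "pixel_bounds": (0, 0, 100, 100),
    "pred": 1,
    "predicted_label": "railspace",
    "conf": 0.65,
}

# Write the row to an in-memory csv, then read it back.
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=list(row))
writer.writeheader()
writer.writerow(row)
buffer.seek(0)
loaded = next(csv.DictReader(buffer))

# After the round trip the bounds are plain text, not a tuple.
raw_bounds = loaded["pixel_bounds"]

# literal_eval safely turns the text back into a Python tuple.
parsed_bounds = literal_eval(raw_bounds)
min_x, min_y, max_x, max_y = parsed_bounds
```

The same parsing step is what lets the class split ``pixel_bounds`` into separate ``min_x``, ``min_y``, ``max_x`` and ``max_y`` columns.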
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
#!/usr/bin/env python | ||
from __future__ import annotations | ||
|
||
from ast import literal_eval | ||
from itertools import product | ||
|
||
import pandas as pd | ||
from tqdm import tqdm | ||
|
||
|
||
class PatchDataFrame(pd.DataFrame): | ||
"""A class for storing patch dataframes. | ||
|
||
Parameters | ||
---------- | ||
patch_df : pd.DataFrame | ||
the DataFrame containing patches and predictions | ||
labels_map : dict | ||
the dictionary mapping label indices to their labels. | ||
e.g. `{0: "no", 1: "railspace"}`. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
patch_df: pd.DataFrame, | ||
labels_map: dict, | ||
): | ||
required_columns = [ | ||
"parent_id", | ||
"pixel_bounds", | ||
"pred", | ||
"predicted_label", | ||
"conf", | ||
] | ||
if not all([col in patch_df.columns for col in required_columns]): | ||
raise ValueError( | ||
f"[ERROR] Your dataframe must contain the following columns: {required_columns}." | ||
) | ||
|
||
# ensure lists/tuples are evaluated as such | ||
for col in patch_df.columns: | ||
try: | ||
patch_df[col] = patch_df[col].apply(literal_eval) | ||
except (ValueError, TypeError, SyntaxError): | ||
pass | ||
|
||
if patch_df.index.name != "image_id" and "image_id" in patch_df.columns: | ||
patch_df.set_index("image_id", drop=True, inplace=True) | ||
|
||
if all( | ||
[col in patch_df.columns for col in ["min_x", "min_y", "max_x", "max_y"]] | ||
): | ||
print( | ||
"[INFO] Using existing pixel bounds columns (min_x, min_y, max_x, max_y)." | ||
) | ||
else: | ||
patch_df[["min_x", "min_y", "max_x", "max_y"]] = [*patch_df["pixel_bounds"]] | ||
|
||
super().__init__(patch_df) | ||
|
||
self.labels_map = labels_map | ||
self._label_patches = None | ||
self.context = {} | ||
|
||
def get_context( | ||
self, | ||
labels: str | list, | ||
): | ||
"""Get the context of the patches with the specified labels. | ||
|
||
Parameters | ||
---------- | ||
labels : str | list | ||
The label(s) to get context for. | ||
""" | ||
if isinstance(labels, str): | ||
labels = [labels] | ||
self._label_patches = self[self["predicted_label"].isin(labels)] | ||
|
||
for id in tqdm(self._label_patches.index): | ||
if id not in self.context: | ||
context_list = self._get_context_id(id) | ||
# only add context if all surrounding patches are found | ||
if len(context_list) == 9: | ||
self.context[id] = context_list | ||
|
||
def _get_context_id( | ||
self, | ||
id, | ||
): | ||
"""Get the context of the patch with the specified index.""" | ||
parent_id = self.loc[id, "parent_id"] | ||
min_x = self.loc[id, "min_x"] | ||
min_y = self.loc[id, "min_y"] | ||
max_x = self.loc[id, "max_x"] | ||
max_y = self.loc[id, "max_y"] | ||
|
||
context_grid = [ | ||
*product( | ||
[(self["min_x"], min_x), (min_x, max_x), (max_x, self["max_x"])], | ||
[(self["min_y"], min_y), (min_y, max_y), (max_y, self["max_y"])], | ||
) | ||
] | ||
# reshape to min_x, min_y, max_x, max_y | ||
context_grid = [(x[0][0], x[1][0], x[0][1], x[1][1]) for x in context_grid] | ||
|
||
context_list = [ | ||
self[ | ||
(self["min_x"] == context_loc[0]) | ||
& (self["min_y"] == context_loc[1]) | ||
& (self["max_x"] == context_loc[2]) | ||
& (self["max_y"] == context_loc[3]) | ||
& (self["parent_id"] == parent_id) | ||
] | ||
for context_loc in context_grid | ||
] | ||
context_list = [x.index[0] for x in context_list if len(x)] | ||
return context_list | ||
|
||
def update_preds(self, remap: dict, conf: float = 0.7, inplace: bool = False): | ||
"""Update the predictions of the chosen patches based on their context. | ||
|
||
Parameters | ||
---------- | ||
remap : dict | ||
A dictionary mapping the old labels to the new labels. | ||
conf : float, optional | ||
Patches with confidence scores below this value will be relabelled, by default 0.7. | ||
inplace : bool, optional | ||
Whether to relabel inplace or create new dataframe columns, by default False | ||
""" | ||
if self._label_patches is None: | ||
raise ValueError("[ERROR] You must run `get_context` first.") | ||
if len(self.context) == 0: | ||
raise ValueError( | ||
"[ERROR] No patches to update. Try changing which labels you are updating." | ||
) | ||
|
||
labels = self._label_patches["predicted_label"].unique() | ||
if any([label not in remap.keys() for label in labels]): | ||
raise ValueError( | ||
f"[ERROR] You must specify a remap for each label in {labels}." | ||
) | ||
|
||
# add new label to labels_map if not already present (assume label index is next in sequence) | ||
for new_label in remap.values(): | ||
if new_label not in self.labels_map.values(): | ||
print( | ||
f"[INFO] Adding {new_label} to labels_map at index {len(self.labels_map)}." | ||
) | ||
self.labels_map[len(self.labels_map)] = new_label | ||
|
||
for id in tqdm(self.context): | ||
self._update_preds_id( | ||
id, labels=labels, remap=remap, conf=conf, inplace=inplace | ||
) | ||
|
||
def _update_preds_id( | ||
self, id, labels: str | list, remap: dict, conf: float, inplace: bool = False | ||
): | ||
"""Update the predictions of the patch with the specified index.""" | ||
context_list = self.context[id] | ||
|
||
context_df = self[self.index.isin(context_list)] | ||
# drop central patch from context (returns a new frame rather than modifying a slice in place) | ||
context_df = context_df.drop(index=id) | ||
|
||
# reverse the labels_map dict | ||
label_index_dict = {v: k for k, v in self.labels_map.items()} | ||
|
||
prefix = "" if inplace else "new_" | ||
if (not any(context_df["predicted_label"].isin(labels))) & ( | ||
self.loc[id, "conf"] < conf | ||
): | ||
self.loc[id, f"{prefix}predicted_label"] = remap[ | ||
self.loc[id, "predicted_label"] | ||
] | ||
self.loc[id, f"{prefix}pred"] = label_index_dict[ | ||
self.loc[id, f"{prefix}predicted_label"] | ||
] |
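For readers following the context lookup in ``_get_context_id``, here is a simplified, standalone sketch of how the 3x3 block of bounds around a patch can be enumerated with ``itertools.product``. It is not the code from this PR: the helper name ``neighbour_bounds`` is hypothetical, and it assumes a regular grid of equally sized patches, whereas the method above matches neighbours by their shared edges without assuming a fixed size.

```python
from itertools import product


def neighbour_bounds(min_x, min_y, max_x, max_y):
    """Return the bounds of the 3x3 block centred on a patch.

    Assumes a regular grid of equally sized patches (an assumption of
    this sketch, not of the PR). Each entry is a
    (min_x, min_y, max_x, max_y) tuple; the centre patch is included.
    """
    width = max_x - min_x
    height = max_y - min_y
    x_ranges = [(min_x - width, min_x), (min_x, max_x), (max_x, max_x + width)]
    y_ranges = [(min_y - height, min_y), (min_y, max_y), (max_y, max_y + height)]
    # Combine every x range with every y range, then reorder each pair
    # into (min_x, min_y, max_x, max_y).
    return [(xr[0], yr[0], xr[1], yr[1]) for xr, yr in product(x_ranges, y_ranges)]


grid = neighbour_bounds(100, 100, 200, 200)
```

In the PR, a patch is only given a context when all nine of these locations are found in the same parent image, which is how edge patches end up being skipped.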
I guess you could be even more explicit and say: "The current method checks whether any of the 8 surrounding patches have the same label as a given patch (e.g. railspace), and if not, assumes this to be a false positive".
Perhaps could also mention: "Future releases may add functionality to create custom filter rules for your use case"
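To make the rule in this comment concrete, here is a minimal sketch of the 8-neighbour check on a toy grid of labels. It is pure Python and independent of MapReader: the function name ``relabel_isolated`` and the list-of-lists layout are assumptions made for illustration, and neighbours are always read from the original grid rather than from already-updated labels.

```python
def relabel_isolated(grid, target_labels, remap, conf_grid, conf=0.7):
    """Return a copy of `grid` with isolated target patches relabelled.

    A patch is treated as isolated when none of its 8 neighbours carries
    one of `target_labels`. Only isolated patches whose confidence is
    below `conf` are relabelled.
    """
    n_rows, n_cols = len(grid), len(grid[0])
    updated = [row[:] for row in grid]
    # Skip edge patches: they do not have all 8 neighbours.
    for r in range(1, n_rows - 1):
        for c in range(1, n_cols - 1):
            label = grid[r][c]
            if label not in target_labels:
                continue
            neighbours = [
                grid[r + dr][c + dc]
                for dr in (-1, 0, 1)
                for dc in (-1, 0, 1)
                if (dr, dc) != (0, 0)
            ]
            isolated = not any(n in target_labels for n in neighbours)
            if isolated and conf_grid[r][c] < conf:
                updated[r][c] = remap[label]
    return updated


labels = [
    ["no", "no", "no", "no"],
    ["no", "railspace", "no", "no"],
    ["no", "no", "no", "no"],
    ["no", "no", "no", "railspace"],
]
confs = [
    [0.9, 0.9, 0.9, 0.9],
    [0.9, 0.6, 0.9, 0.9],
    [0.9, 0.9, 0.9, 0.9],
    [0.9, 0.9, 0.9, 0.5],
]
result = relabel_isolated(labels, {"railspace"}, {"railspace": "no"}, confs)
```

Here the interior 'railspace' patch has no 'railspace' neighbours and a confidence below 0.7, so it is remapped to 'no'. The corner 'railspace' patch is left alone because it sits on the edge of the grid.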