Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context-based post processing for linear features #342

Merged
merged 9 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 91 additions & 1 deletion docs/source/User-guide/Post-process.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,94 @@
Post-process
=============

TBC
MapReader post-processing's sub-package currently contains one method for post-processing the predictions from your model based on the idea that features such as railways, roads, coastlines, etc. are continuous and so patches with these labels should be found near to other patches also with these labels.
For example, if a patch is predicted to be a railspace, but is surrounded by patches predicted to be non-railspace, then it is likely that the railspace patch is a false positive.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you could be even more explicit and say: "The current method checks whether any of the 8 surrounding patches have the same label as a given patch (e.g. railspace), and if not, assumes this to be a false positive".

Perhaps could also mention: "Future releases may add functionality to create custom filter rules for your use case"


To implement this, for a given patch, the code checks whether any of the 8 surrounding patches have the same label (e.g. 'railspace') and, if not, assumes the current patch's predicted label to be a false positive.
The user can then choose how to relabel the patch (e.g. 'railspace' -> 'no').

To run the post-processing code, you will need to have saved the predictions from your model in the format expected for the post-processing code.
See the :doc:`/User-guide/Classify/Classify` docs for more on this.

If you have your predictions saved in a csv file, you will first need to load them into a pandas DataFrame:

.. code-block:: python

import pandas as pd

preds = pd.read_csv("path/to/predictions.csv", index_col=0)


You can then run the post-processing code as follows:

.. code-block:: python

from mapreader.process.post_process import PatchDataFrame

labels_map = {
0: "no",
1: "railspace",
2: "building",
3: "railspace&building"
}

patches = PatchDataFrame(preds, labels_map=labels_map)

MapReader's post-processing will only work for features that are expected be continuous (e.g. railway, road, coastline, etc.) or clustered (e.g. a large body of water).
You will need to tell MapReader which labels to select and then get the context for each of the relevant patches in order to work out if it is isolated or part of a line/cluster.

For example, if you want to post-process patches which are predicted to be 'railspace' or 'railspace&building', you would do the following:

.. code-block:: python

labels=["railspace", "railspace&building"]
patches.get_context(labels=labels)


.. note:: In the above example, we needed to use both 'railspace' and 'railspace&building' as our labels, since the continuous feature we are trying to post-process is railway lines (included in both these labels).

You will also need to tell MapReader how to update the label of each patch that is isolated and therefore likely to be a false positive.
This is done using the ``remap`` argument, which takes a dictionary of the form ``{old_label: new_label}``.

For example, if you want to remap all isolated 'railspace' patches to be labelled as 'no', and all isolated 'railspace&building' patches to be labelled as 'building', you would do the following:

.. code-block:: python

remap={"railspace": "no", "railspace&building": "building"}
patches.update_preds(remap=remap)

By default, only patches with model confidence of below 0.7 will be relabelled.
You can adjust this by passing the ``conf`` argument.

e.g. to relabel all isolated patches with confidence below 0.9, you would do the following:

.. code-block:: python

remap={"railspace": "no", "railspace&building": "building"}
patches.update_preds(remap=remap, conf=0.9)

Instead of relabelling your chosen patches to an existing label, you can also choose to relabel them to a new label.
For example, to mark them as 'false_positive', you would do the following:

.. code-block:: python

remap={"railspace": "false_positive", "railspace&building": "false_positive"}
patches.update_preds(remap=remap)


By default, after running `update_preds`, a new column will be added to your ``patches`` DataFrame called "new_predicted_label".
This will contain the updated predictions (or NaN if the patch was not relabelled).

Alternatively, to save the updated predictions inplace you can pass the ``inplace`` argument:

.. code-block:: python

remap={"railspace": "no", "railspace&building": "building"}
patches.update_preds(remap=remap, inplace=True)


Finally, to save your outputs to a csv file, you can do the following:

.. code-block:: python

patches.to_csv("path/to/save/updated_predictions.csv")
39 changes: 29 additions & 10 deletions mapreader/classify/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,15 @@ def __init__(
raise ValueError(
"[ERROR] ``labels_map`` and ``load_path`` cannot be used together - please set one to ``None``."
)

# load object
self.load(load_path=load_path, force_device=force_device)

# add any extra dataloaders
if dataloaders:
for set_name, dataloader in dataloaders.items():
self.dataloaders[set_name]=dataloader
self.dataloaders[set_name] = dataloader

else:
if model is None or labels_map is None:
raise ValueError(
Expand All @@ -144,7 +144,7 @@ def __init__(

self.labels_map = labels_map

# set up model and move to device
# set up model and move to device
print("[INFO] Initializing model.")
if isinstance(model, nn.Module):
self.model = model.to(self.device)
Expand Down Expand Up @@ -174,11 +174,9 @@ def __init__(

# add dataloaders and labels_map
self.dataloaders = dataloaders if dataloaders else {}

for set_name, dataloader in self.dataloaders.items():
print(
f'[INFO] Loaded "{set_name}" with {len(dataloader.dataset)} items.'
)
print(f'[INFO] Loaded "{set_name}" with {len(dataloader.dataset)} items.')

def generate_layerwise_lrs(
self,
Expand Down Expand Up @@ -892,7 +890,7 @@ def train_core(
raise ValueError(
"[ERROR] Criterion is not yet defined.\n\n\
Use ``add_criterion`` to define one."
)
)

if self.is_inception and (
phase.lower() in train_phase_names
Expand Down Expand Up @@ -1762,6 +1760,27 @@ def save(
os.path.join(par_name, f"model_state_dict_{base_name}"),
)

def save_predictions(
self,
set_name: str,
save_path: str | None = None,
delimiter: str = ",",
):
if set_name not in self.dataloaders.keys():
raise ValueError(
f"[ERROR] ``set_name`` must be one of {list(self.dataloaders.keys())}."
)

patch_df = self.dataloaders[set_name].dataset.patch_df
patch_df["predicted_label"] = self.pred_label
patch_df["pred"] = self.pred_label_indices
patch_df["conf"] = np.array(self.pred_conf).max(axis=1)

if save_path is None:
save_path = f"{set_name}_predictions_patch_df.csv"
patch_df.to_csv(save_path, sep=delimiter)
print(f"[INFO] Saved predictions to {save_path}.")

def load_dataset(
self,
dataset: PatchDataset,
Expand Down
182 changes: 182 additions & 0 deletions mapreader/process/post_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#!/usr/bin/env python
from __future__ import annotations

from ast import literal_eval
from itertools import product

import pandas as pd
from tqdm import tqdm


class PatchDataFrame(pd.DataFrame):
"""A class for storing patch dataframes.

Parameters
----------
patch_df : pd.DataFrame
the DataFrame containing patches and predictions
labels_map : dict
the dictionary mapping label indices to their labels.
e.g. `{0: "no", 1: "railspace"}`.
"""

def __init__(
self,
patch_df: pd.DataFrame,
labels_map: dict,
):
required_columns = [
"parent_id",
"pixel_bounds",
"pred",
"predicted_label",
"conf",
]
if not all([col in patch_df.columns for col in required_columns]):
raise ValueError(
f"[ERROR] Your dataframe must contain the following columns: {required_columns}."
)

# ensure lists/tuples are evaluated as such
for col in patch_df.columns:
try:
patch_df[col] = patch_df[col].apply(literal_eval)
except (ValueError, TypeError, SyntaxError):
pass

if patch_df.index.name != "image_id" and "image_id" in patch_df.columns:
patch_df.set_index("image_id", drop=True, inplace=True)

if all(
[col in patch_df.columns for col in ["min_x", "min_y", "max_x", "max_y"]]
):
print(
"[INFO] Using existing pixel bounds columns (min_x, min_y, max_x, max_y)."
)
else:
patch_df[["min_x", "min_y", "max_x", "max_y"]] = [*patch_df["pixel_bounds"]]

super().__init__(patch_df)

self.labels_map = labels_map
self._label_patches = None
self.context = {}

def get_context(
self,
labels: str | list,
):
"""Get the context of the patches with the specified labels.

Parameters
----------
labels : str | list
The label(s) to get context for.
"""
if isinstance(labels, str):
labels = [labels]
self._label_patches = self[self["predicted_label"].isin(labels)]

for id in tqdm(self._label_patches.index):
if id not in self.context:
context_list = self._get_context_id(id)
# only add context if all surrounding patches are found
if len(context_list) == 9:
self.context[id] = context_list

def _get_context_id(
self,
id,
):
"""Get the context of the patch with the specified index."""
parent_id = self.loc[id, "parent_id"]
min_x = self.loc[id, "min_x"]
min_y = self.loc[id, "min_y"]
max_x = self.loc[id, "max_x"]
max_y = self.loc[id, "max_y"]

context_grid = [
*product(
[(self["min_x"], min_x), (min_x, max_x), (max_x, self["max_x"])],
[(self["min_y"], min_y), (min_y, max_y), (max_y, self["max_y"])],
)
]
# reshape to min_x, min_y, max_x, max_y
context_grid = [(x[0][0], x[1][0], x[0][1], x[1][1]) for x in context_grid]

context_list = [
self[
(self["min_x"] == context_loc[0])
& (self["min_y"] == context_loc[1])
& (self["max_x"] == context_loc[2])
& (self["max_y"] == context_loc[3])
& (self["parent_id"] == parent_id)
]
for context_loc in context_grid
]
context_list = [x.index[0] for x in context_list if len(x)]
return context_list

def update_preds(self, remap: dict, conf: float = 0.7, inplace: bool = False):
"""Update the predictions of the chosen patches based on their context.

Parameters
----------
remap : dict
A dictionary mapping the old labels to the new labels.
conf : float, optional
Patches with confidence scores below this value will be relabelled, by default 0.7.
inplace : bool, optional
Whether to relabel inplace or create new dataframe columns, by default False
"""
if self._label_patches is None:
raise ValueError("[ERROR] You must run `get_context` first.")
if len(self.context) == 0:
raise ValueError(
"[ERROR] No patches to update. Try changing which labels you are updating."
)

labels = self._label_patches["predicted_label"].unique()
if any([label not in remap.keys() for label in labels]):
raise ValueError(
f"[ERROR] You must specify a remap for each label in {labels}."
)

# add new label to labels_map if not already present (assume label index is next in sequence)
for new_label in remap.values():
if new_label not in self.labels_map.values():
print(
[
f"[INFO] Adding {new_label} to labels_map at index {len(self.labels_map)}."
]
)
self.labels_map[len(self.labels_map)] = new_label

for id in tqdm(self.context):
self._update_preds_id(
id, labels=labels, remap=remap, conf=conf, inplace=inplace
)

def _update_preds_id(
self, id, labels: str | list, remap: dict, conf: float, inplace: bool = False
):
"""Update the predictions of the patch with the specified index."""
context_list = self.context[id]

context_df = self[self.index.isin(context_list)]
# drop central patch from context
context_df.drop(index=id, inplace=True)

# reverse the labels_map dict
label_index_dict = {v: k for k, v in self.labels_map.items()}

prefix = "" if inplace else "new_"
if (not any(context_df["predicted_label"].isin(labels))) & (
rwood-97 marked this conversation as resolved.
Show resolved Hide resolved
self.loc[id, "conf"] < conf
):
self.loc[id, f"{prefix}predicted_label"] = remap[
self.loc[id, "predicted_label"]
]
self.loc[id, f"{prefix}pred"] = label_index_dict[
self.loc[id, f"{prefix}predicted_label"]
]
Loading
Loading