Merge pull request #342 from Living-with-machines/339-postproc
Add context-based post processing for linear features
rwood-97 authored Feb 5, 2024
2 parents 6228aa2 + 777c857 commit 033917f
Showing 5 changed files with 536 additions and 11 deletions.
92 changes: 91 additions & 1 deletion docs/source/User-guide/Post-process.rst
@@ -1,4 +1,94 @@
Post-process
=============

MapReader's post-processing sub-package currently contains one method for post-processing the predictions from your model.
It is based on the idea that features such as railways, roads and coastlines are continuous, and so patches with these labels should be found near other patches with the same labels.
For example, if a patch is predicted to be railspace but is surrounded by patches predicted to be non-railspace, then the railspace prediction is likely to be a false positive.

To implement this, for a given patch, the code checks whether any of the 8 surrounding patches has the same label (e.g. 'railspace') and, if not, treats the current patch's predicted label as a false positive.
The user can then choose how to relabel the patch (e.g. 'railspace' -> 'no').
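
Conceptually, the check looks something like the following sketch (``is_isolated`` is a hypothetical helper for illustration, not part of MapReader's API):

.. code-block:: python

    def is_isolated(centre_label, neighbour_labels, target_labels):
        # a patch is isolated if it has a target label but none of its
        # 8 neighbours does
        return centre_label in target_labels and not any(
            label in target_labels for label in neighbour_labels
        )

    # a 'railspace' patch surrounded entirely by 'no' patches is isolated
    is_isolated("railspace", ["no"] * 8, {"railspace", "railspace&building"})  # True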

To run the post-processing code, you will need to have saved your model's predictions in the expected format.
See the :doc:`/User-guide/Classify/Classify` docs for more on this.
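
For example, the ``save_predictions`` method (added to the classifier in this commit) writes the patch DataFrame, including the ``predicted_label``, ``pred`` and ``conf`` columns used below, to a csv file. A minimal sketch, assuming your trained classifier is ``my_classifier`` and your inference dataloader is named "infer":

.. code-block:: python

    # saves to "infer_predictions_patch_df.csv" by default
    my_classifier.save_predictions("infer")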

If you have your predictions saved in a csv file, you will first need to load them into a pandas DataFrame:

.. code-block:: python

    import pandas as pd

    preds = pd.read_csv("path/to/predictions.csv", index_col=0)

You can then run the post-processing code as follows:

.. code-block:: python

    from mapreader.process.post_process import PatchDataFrame

    labels_map = {
        0: "no",
        1: "railspace",
        2: "building",
        3: "railspace&building",
    }
    patches = PatchDataFrame(preds, labels_map=labels_map)

MapReader's post-processing will only work for features that are expected to be continuous (e.g. railways, roads, coastlines) or clustered (e.g. a large body of water).
You will need to tell MapReader which labels to select, and then get the context of each relevant patch in order to work out whether it is isolated or part of a line/cluster.

For example, if you want to post-process patches which are predicted to be 'railspace' or 'railspace&building', you would do the following:

.. code-block:: python

    labels = ["railspace", "railspace&building"]
    patches.get_context(labels=labels)

.. note:: In the above example, we needed to use both 'railspace' and 'railspace&building' as our labels, since the continuous feature we are trying to post-process is railway lines (included in both these labels).

You will also need to tell MapReader how to update the label of each patch that is isolated and therefore likely to be a false positive.
This is done using the ``remap`` argument, which takes a dictionary of the form ``{old_label: new_label}``.

For example, if you want to remap all isolated 'railspace' patches to be labelled as 'no', and all isolated 'railspace&building' patches to be labelled as 'building', you would do the following:

.. code-block:: python

    remap = {"railspace": "no", "railspace&building": "building"}
    patches.update_preds(remap=remap)

By default, only patches with a model confidence below 0.7 will be relabelled.
You can adjust this by passing the ``conf`` argument.

For example, to relabel all isolated patches with confidence below 0.9, you would do the following:

.. code-block:: python

    remap = {"railspace": "no", "railspace&building": "building"}
    patches.update_preds(remap=remap, conf=0.9)

Instead of relabelling your chosen patches with an existing label, you can also relabel them with a new one.
For example, to mark them as 'false_positive', you would do the following:

.. code-block:: python

    remap = {"railspace": "false_positive", "railspace&building": "false_positive"}
    patches.update_preds(remap=remap)

By default, after running ``update_preds``, a new column called ``new_predicted_label`` will be added to your ``patches`` DataFrame.
This will contain the updated predictions (or NaN if the patch was not relabelled).
A ``new_pred`` column, containing the corresponding updated label indices, is added in the same way.
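
For example, assuming at least one patch was relabelled, you can inspect the relabelled patches as follows (a minimal sketch):

.. code-block:: python

    relabelled = patches[patches["new_predicted_label"].notna()]
    print(f"[INFO] {len(relabelled)} patches were relabelled.")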

Alternatively, to overwrite the existing ``predicted_label`` and ``pred`` columns in place, you can pass the ``inplace`` argument:

.. code-block:: python

    remap = {"railspace": "no", "railspace&building": "building"}
    patches.update_preds(remap=remap, inplace=True)

Finally, to save your outputs to a csv file, you can do the following:

.. code-block:: python

    patches.to_csv("path/to/save/updated_predictions.csv")
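
Putting it all together, a minimal end-to-end sketch (reusing the example paths and labels from above):

.. code-block:: python

    import pandas as pd

    from mapreader.process.post_process import PatchDataFrame

    preds = pd.read_csv("path/to/predictions.csv", index_col=0)
    labels_map = {0: "no", 1: "railspace", 2: "building", 3: "railspace&building"}

    patches = PatchDataFrame(preds, labels_map=labels_map)
    patches.get_context(labels=["railspace", "railspace&building"])
    patches.update_preds(remap={"railspace": "no", "railspace&building": "building"})
    patches.to_csv("path/to/save/updated_predictions.csv")
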
39 changes: 29 additions & 10 deletions mapreader/classify/classifier.py
@@ -127,15 +127,15 @@ def __init__(
raise ValueError(
"[ERROR] ``labels_map`` and ``load_path`` cannot be used together - please set one to ``None``."
)

# load object
self.load(load_path=load_path, force_device=force_device)

# add any extra dataloaders
if dataloaders:
for set_name, dataloader in dataloaders.items():
self.dataloaders[set_name]=dataloader
self.dataloaders[set_name] = dataloader

else:
if model is None or labels_map is None:
raise ValueError(
Expand All @@ -144,7 +144,7 @@ def __init__(

self.labels_map = labels_map

# set up model and move to device
# set up model and move to device
print("[INFO] Initializing model.")
if isinstance(model, nn.Module):
self.model = model.to(self.device)
@@ -174,11 +174,9 @@ def __init__(

# add dataloaders and labels_map
self.dataloaders = dataloaders if dataloaders else {}

for set_name, dataloader in self.dataloaders.items():
print(
f'[INFO] Loaded "{set_name}" with {len(dataloader.dataset)} items.'
)
print(f'[INFO] Loaded "{set_name}" with {len(dataloader.dataset)} items.')

def generate_layerwise_lrs(
self,
@@ -892,7 +890,7 @@ def train_core(
raise ValueError(
"[ERROR] Criterion is not yet defined.\n\n\
Use ``add_criterion`` to define one."
)
)

if self.is_inception and (
phase.lower() in train_phase_names
@@ -1762,6 +1760,27 @@ def save(
os.path.join(par_name, f"model_state_dict_{base_name}"),
)

def save_predictions(
self,
set_name: str,
save_path: str | None = None,
delimiter: str = ",",
):
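        """Save the predictions of the model for the specified set to a csv file.

        Parameters
        ----------
        set_name : str
            The name of the dataloader whose predictions should be saved.
        save_path : str | None, optional
            The path to save the csv file to. If ``None``, predictions are
            saved to ``f"{set_name}_predictions_patch_df.csv"``. By default
            ``None``.
        delimiter : str, optional
            The delimiter to use in the csv file, by default ",".
        """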
if set_name not in self.dataloaders.keys():
raise ValueError(
f"[ERROR] ``set_name`` must be one of {list(self.dataloaders.keys())}."
)

patch_df = self.dataloaders[set_name].dataset.patch_df
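        # add the predicted labels, label indices and max confidence across classes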
patch_df["predicted_label"] = self.pred_label
patch_df["pred"] = self.pred_label_indices
patch_df["conf"] = np.array(self.pred_conf).max(axis=1)

if save_path is None:
save_path = f"{set_name}_predictions_patch_df.csv"
patch_df.to_csv(save_path, sep=delimiter)
print(f"[INFO] Saved predictions to {save_path}.")

def load_dataset(
self,
dataset: PatchDataset,
182 changes: 182 additions & 0 deletions mapreader/process/post_process.py
@@ -0,0 +1,182 @@
#!/usr/bin/env python
from __future__ import annotations

from ast import literal_eval
from itertools import product

import pandas as pd
from tqdm import tqdm


class PatchDataFrame(pd.DataFrame):
"""A class for storing patch dataframes.
Parameters
----------
patch_df : pd.DataFrame
the DataFrame containing patches and predictions
labels_map : dict
the dictionary mapping label indices to their labels.
e.g. `{0: "no", 1: "railspace"}`.
"""

def __init__(
self,
patch_df: pd.DataFrame,
labels_map: dict,
):
required_columns = [
"parent_id",
"pixel_bounds",
"pred",
"predicted_label",
"conf",
]
if not all([col in patch_df.columns for col in required_columns]):
raise ValueError(
f"[ERROR] Your dataframe must contain the following columns: {required_columns}."
)

# ensure lists/tuples are evaluated as such
for col in patch_df.columns:
try:
patch_df[col] = patch_df[col].apply(literal_eval)
except (ValueError, TypeError, SyntaxError):
pass

if patch_df.index.name != "image_id" and "image_id" in patch_df.columns:
patch_df.set_index("image_id", drop=True, inplace=True)

if all(
[col in patch_df.columns for col in ["min_x", "min_y", "max_x", "max_y"]]
):
print(
"[INFO] Using existing pixel bounds columns (min_x, min_y, max_x, max_y)."
)
else:
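            # unpack the pixel_bounds tuples into separate min_x, min_y, max_x, max_y columns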
patch_df[["min_x", "min_y", "max_x", "max_y"]] = [*patch_df["pixel_bounds"]]

super().__init__(patch_df)

self.labels_map = labels_map
self._label_patches = None
self.context = {}

def get_context(
self,
labels: str | list,
):
"""Get the context of the patches with the specified labels.
Parameters
----------
labels : str | list
The label(s) to get context for.
"""
if isinstance(labels, str):
labels = [labels]
self._label_patches = self[self["predicted_label"].isin(labels)]

for id in tqdm(self._label_patches.index):
if id not in self.context:
context_list = self._get_context_id(id)
# only add context if all surrounding patches are found
if len(context_list) == 9:
self.context[id] = context_list

def _get_context_id(
self,
id,
):
"""Get the context of the patch with the specified index."""
parent_id = self.loc[id, "parent_id"]
min_x = self.loc[id, "min_x"]
min_y = self.loc[id, "min_y"]
max_x = self.loc[id, "max_x"]
max_y = self.loc[id, "max_y"]

        # build the 3x3 grid of bounding boxes centred on this patch; using the
        # full min_x/min_y/max_x/max_y columns as the outer bounds acts as a
        # wildcard that matches whatever coordinates the neighbouring patch has
        context_grid = [
            *product(
                [(self["min_x"], min_x), (min_x, max_x), (max_x, self["max_x"])],
                [(self["min_y"], min_y), (min_y, max_y), (max_y, self["max_y"])],
            )
        ]
# reshape to min_x, min_y, max_x, max_y
context_grid = [(x[0][0], x[1][0], x[0][1], x[1][1]) for x in context_grid]

context_list = [
self[
(self["min_x"] == context_loc[0])
& (self["min_y"] == context_loc[1])
& (self["max_x"] == context_loc[2])
& (self["max_y"] == context_loc[3])
& (self["parent_id"] == parent_id)
]
for context_loc in context_grid
]
context_list = [x.index[0] for x in context_list if len(x)]
return context_list

def update_preds(self, remap: dict, conf: float = 0.7, inplace: bool = False):
"""Update the predictions of the chosen patches based on their context.
Parameters
----------
remap : dict
A dictionary mapping the old labels to the new labels.
conf : float, optional
Patches with confidence scores below this value will be relabelled, by default 0.7.
inplace : bool, optional
Whether to relabel inplace or create new dataframe columns, by default False
"""
if self._label_patches is None:
raise ValueError("[ERROR] You must run `get_context` first.")
if len(self.context) == 0:
raise ValueError(
"[ERROR] No patches to update. Try changing which labels you are updating."
)

labels = self._label_patches["predicted_label"].unique()
if any([label not in remap.keys() for label in labels]):
raise ValueError(
f"[ERROR] You must specify a remap for each label in {labels}."
)

# add new label to labels_map if not already present (assume label index is next in sequence)
for new_label in remap.values():
if new_label not in self.labels_map.values():
                print(
                    f"[INFO] Adding {new_label} to labels_map at index {len(self.labels_map)}."
                )
self.labels_map[len(self.labels_map)] = new_label

for id in tqdm(self.context):
self._update_preds_id(
id, labels=labels, remap=remap, conf=conf, inplace=inplace
)

def _update_preds_id(
self, id, labels: str | list, remap: dict, conf: float, inplace: bool = False
):
"""Update the predictions of the patch with the specified index."""
context_list = self.context[id]

context_df = self[self.index.isin(context_list)]
# drop central patch from context
context_df.drop(index=id, inplace=True)

# reverse the labels_map dict
label_index_dict = {v: k for k, v in self.labels_map.items()}

prefix = "" if inplace else "new_"
        if not any(context_df["predicted_label"].isin(labels)) and (
            self.loc[id, "conf"] < conf
        ):
self.loc[id, f"{prefix}predicted_label"] = remap[
self.loc[id, "predicted_label"]
]
self.loc[id, f"{prefix}pred"] = label_index_dict[
self.loc[id, f"{prefix}predicted_label"]
]
