Merge pull request #80 from p-lambda/dev
v1.2.2
ssagawa authored Aug 4, 2021
2 parents 1d06a18 + 88ba842 commit 061fd04
Showing 9 changed files with 76 additions and 48 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -29,7 +29,7 @@ pip install wilds
If you have already installed it, please check that you have the latest version:
```bash
python -c "import wilds; print(wilds.__version__)"
# This should print "1.2.1". If it doesn't, update by running:
# This should print "1.2.2". If it doesn't, update by running:
pip install -U wilds
```

@@ -50,7 +50,10 @@ pip install -e .
- torch>=1.7.0
- torch-scatter>=2.0.5
- torch-geometric>=1.6.1
- torchvision>=0.8.2
- tqdm>=4.53.0
- scikit-learn>=0.20.0
- scipy>=1.5.4

Running `pip install wilds` or `pip install -e .` will automatically check for and install all of these requirements
except for the `torch-scatter` and `torch-geometric` packages, which require a [quick manual install](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html#installation-via-binaries).
@@ -63,9 +66,8 @@ These scripts are not part of the installed WILDS package. To use them, you shou
git clone [email protected]:p-lambda/wilds.git
```

To run these scripts, you will need to install these additional dependencies:
To run these scripts, you will also need to install this additional dependency:

- torchvision>=0.8.2
- transformers>=3.5.0

All baseline experiments in the paper were run on Python 3.8.5 and CUDA 10.1.
36 changes: 25 additions & 11 deletions examples/configs/utils.py
@@ -1,3 +1,4 @@
import copy
from configs.algorithm import algorithm_defaults
from configs.model import model_defaults
from configs.scheduler import scheduler_defaults
@@ -7,41 +8,44 @@
def populate_defaults(config):
"""Populates hyperparameters with defaults implied by choices
of other hyperparameters."""

orig_config = copy.deepcopy(config)
assert config.dataset is not None, 'dataset must be specified'
assert config.algorithm is not None, 'algorithm must be specified'

# implied defaults from choice of dataset
config = populate_config(
config,
config,
dataset_defaults[config.dataset]
)

# implied defaults from choice of split
if config.dataset in split_defaults and config.split_scheme in split_defaults[config.dataset]:
config = populate_config(
config,
config,
split_defaults[config.dataset][config.split_scheme]
)

# implied defaults from choice of algorithm
config = populate_config(
config,
config,
algorithm_defaults[config.algorithm]
)

# implied defaults from choice of loader
config = populate_config(
config,
config,
loader_defaults
)
# implied defaults from choice of model
if config.model: config = populate_config(
config,
config,
model_defaults[config.model],
)

# implied defaults from choice of scheduler
if config.scheduler: config = populate_config(
config,
config,
scheduler_defaults[config.scheduler]
)

@@ -52,12 +56,22 @@ def populate_defaults(config):

# basic checks
required_fields = [
'split_scheme', 'train_loader', 'uniform_over_groups', 'batch_size', 'eval_loader', 'model', 'loss_function',
'split_scheme', 'train_loader', 'uniform_over_groups', 'batch_size', 'eval_loader', 'model', 'loss_function',
'val_metric', 'val_metric_decreasing', 'n_epochs', 'optimizer', 'lr', 'weight_decay',
]
]
for field in required_fields:
assert getattr(config, field) is not None, f"Must manually specify {field} for this setup."

# data loader validations
# we only raise this error if the train_loader is standard, and
# n_groups_per_batch or distinct_groups are
# specified by the user (instead of populated as a default)
if config.train_loader == 'standard':
if orig_config.n_groups_per_batch is not None:
raise ValueError("n_groups_per_batch cannot be specified if the data loader is 'standard'. Consider using a 'group' data loader instead.")
if orig_config.distinct_groups is not None:
raise ValueError("distinct_groups cannot be specified if the data loader is 'standard'. Consider using a 'group' data loader instead.")

return config

def populate_config(config, template: dict, force_compatibility=False):
@@ -78,7 +92,7 @@ def populate_config(config, template: dict, force_compatibility=False):
d_config[key] = val
elif d_config[key] != val and force_compatibility:
raise ValueError(f"Argument {key} must be set to {val}")

else: # config[key] expected to be a kwarg dict
for kwargs_key, kwargs_val in val.items():
if kwargs_key not in d_config[key] or d_config[key][kwargs_key] is None:
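Note on this change: the `copy.deepcopy` added at the top of `populate_defaults` exists so the new validation can tell arguments the user actually passed apart from defaults merged in later by `populate_config`. A minimal, self-contained sketch of that pattern, using a hypothetical toy namespace rather than the real config object:

```python
# Minimal sketch of the guard added above (toy namespace, not the real
# config object; the real populate_defaults also merges many other defaults).
import argparse
import copy

def validate_loader_args(config):
    # Snapshot before any defaults are merged in, so user-passed flags
    # remain distinguishable from populated defaults.
    orig_config = copy.deepcopy(config)
    # ... populate_config calls would fill in defaults here ...
    if config.train_loader == 'standard':
        if orig_config.n_groups_per_batch is not None:
            raise ValueError(
                "n_groups_per_batch cannot be specified if the data loader "
                "is 'standard'. Consider using a 'group' data loader instead.")
    return config

config = argparse.Namespace(
    train_loader='standard', n_groups_per_batch=2, distinct_groups=None)
validate_loader_args(config)  # raises ValueError, as intended
```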
9 changes: 4 additions & 5 deletions examples/models/detection/fasterrcnn.py
@@ -27,7 +27,6 @@
from torchvision.models.utils import load_state_dict_from_url
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign
from torchvision.models.detection import _utils as det_utils
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN
from torchvision.models.detection.faster_rcnn import TwoMLPHead
@@ -127,11 +126,11 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets)
sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]
sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

box_loss.append(det_utils.smooth_l1_loss(
box_loss.append(F.smooth_l1_loss(
pred_bbox_deltas_[sampled_pos_inds],
regression_targets_[sampled_pos_inds],
beta=1 / 9,
size_average=False,
reduction='sum',
) / (sampled_inds.numel()))

objectness_loss.append(F.binary_cross_entropy_with_logits(
@@ -226,11 +225,11 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):

box_regression_ = box_regression_.reshape(N, -1, 4)

box_loss_ = det_utils.smooth_l1_loss(
box_loss_ = F.smooth_l1_loss(
box_regression_[sampled_pos_inds_subset, labels_pos],
regression_targets_[sampled_pos_inds_subset],
beta=1 / 9,
size_average=False,
reduction='sum',
)
box_loss.append(box_loss_ / labels_.numel())

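The substantive change in this file swaps torchvision's private `det_utils.smooth_l1_loss` for the public `torch.nn.functional.smooth_l1_loss`, with `size_average=False` translated to `reduction='sum'`. A quick sketch on synthetic tensors of what the replacement computes (smooth L1 with `beta=1/9`, summed rather than averaged):

```python
# Sketch verifying what F.smooth_l1_loss(..., reduction='sum') computes
# (synthetic tensors; requires torch>=1.7 for the beta argument).
import torch
import torch.nn.functional as F

pred = torch.randn(8, 4)
target = torch.randn(8, 4)
beta = 1 / 9

loss = F.smooth_l1_loss(pred, target, beta=beta, reduction='sum')

# The same quantity by hand: quadratic below beta, linear above, then summed.
diff = (pred - target).abs()
elementwise = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
assert torch.allclose(loss, elementwise.sum())
```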
5 changes: 3 additions & 2 deletions examples/run_expt.py
@@ -155,8 +155,9 @@ def main():
split_scheme=config.split_scheme,
**config.dataset_kwargs)

# To implement data augmentation (i.e., have different transforms
# at training time vs. test time), modify these two lines:
# To modify data augmentation, modify the following code block.
# If you want to use transforms that modify both `x` and `y`,
# set `do_transform_y` to True when initializing the `WILDSSubset` below.
train_transform = initialize_transform(
transform_name=config.transform,
config=config,
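For readers modifying this block: a hypothetical illustration of train-time augmentation differing from eval-time transforms. These are plain torchvision pipelines for illustration only; the script itself builds its transforms through `initialize_transform`, which takes an `is_training` flag.

```python
# Hypothetical stand-in for the block above: distinct train/eval pipelines.
# The actual script gets this behavior from initialize_transform(...).
import torchvision.transforms as transforms

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # augmentation: training only
    transforms.ToTensor(),
])
eval_transform = transforms.Compose([
    transforms.ToTensor(),              # no augmentation at eval time
])
```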
20 changes: 8 additions & 12 deletions examples/transforms.py
@@ -7,8 +7,9 @@

def initialize_transform(transform_name, config, dataset, is_training):
"""
Transforms should take in a single (x, y)
and return (transformed_x, transformed_y).
By default, transforms should take in `x` and return `transformed_x`.
For transforms that take in `(x, y)` and return `(transformed_x, transformed_y)`,
set `do_transform_y` to True when initializing the WILDSSubset.
"""
if transform_name is None:
return None
@@ -25,11 +26,6 @@ def initialize_transform(transform_name, config, dataset, is_training):
else:
raise ValueError(f"{transform_name} not recognized")

def transform_input_only(input_transform):
def transform(x, y):
return input_transform(x), y
return transform

def initialize_bert_transform(config):
assert 'bert' in config.model
assert config.max_token_length is not None
@@ -55,7 +51,7 @@ def transform(text):
dim=2)
x = torch.squeeze(x, dim=0) # First shape dim is always 1
return x
return transform_input_only(transform)
return transform

def getBertTokenizer(model):
if model == 'bert-base-uncased':
@@ -79,7 +75,7 @@ def initialize_image_base_transform(config, dataset):
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]
transform = transforms.Compose(transform_steps)
return transform_input_only(transform)
return transform

def initialize_image_resize_and_center_crop_transform(config, dataset):
"""
@@ -98,7 +94,7 @@ def initialize_image_resize_and_center_crop_transform(config, dataset):
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
return transform_input_only(transform)
return transform

def initialize_poverty_transform(is_training):
if is_training:
@@ -115,7 +111,7 @@ def transform_rgb(img):
img[:3] = rgb_transform(img[:3][[2,1,0]])[[2,1,0]]
return img
transform = transforms.Lambda(lambda x: transform_rgb(x))
return transform_input_only(transform)
return transform
else:
return None

@@ -148,4 +144,4 @@ def random_rotation(x: torch.Tensor) -> torch.Tensor:
t_standardize,
]
transform = transforms.Compose(transforms_ls)
return transform_input_only(transform)
return transform
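With `transform_input_only` removed, the contract in this file is now: transforms returned by these initializers act on `x` alone, and joint `(x, y)` transforms are opted into via `do_transform_y` on the subset. A small sketch contrasting the two contracts (toy tensors; the joint flip is a hypothetical example, not one of the initializers above):

```python
# Toy contrast of the two transform contracts after this change.
import torch
import torchvision.transforms as transforms

# Default contract: x -> transformed_x (what the initializers now return).
x_only = transforms.Lambda(lambda x: x * 2)

# Joint contract, used with do_transform_y=True: (x, y) -> (x, y).
# Hypothetical example: flip an image and its dense label consistently.
def joint_flip(x, y):
    return torch.flip(x, dims=[-1]), torch.flip(y, dims=[-1])

x = torch.arange(4.0).reshape(1, 2, 2)  # fake image
y = torch.arange(4.0).reshape(1, 2, 2)  # fake dense label (e.g., a mask)
print(x_only(x))         # only x is transformed
print(joint_flip(x, y))  # x and y are flipped together
```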
10 changes: 6 additions & 4 deletions setup.py
@@ -22,14 +22,16 @@
long_description_content_type="text/markdown",
install_requires = [
'numpy>=1.19.1',
'ogb>=1.2.6',
'outdated>=0.2.0',
'pandas>=1.1.0',
'scikit-learn>=0.20.0',
'pillow>=7.2.0',
'pytz>=2020.4',
'torch>=1.7.0',
'ogb>=1.2.6',
'torchvision>=0.8.2',
'tqdm>=4.53.0',
'outdated>=0.2.0',
'pytz>=2020.4',
'scikit-learn>=0.20.0',
'scipy>=1.5.4'
],
license='MIT',
packages=setuptools.find_packages(exclude=['dataset_preprocessing', 'examples', 'examples.models', 'examples.models.bert']),
15 changes: 10 additions & 5 deletions wilds/common/metrics/all_metrics.py
@@ -1,10 +1,10 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops.boxes import box_iou
from torchvision.models.detection._utils import Matcher
from torchvision.ops import nms, box_convert
import numpy as np
import torch.nn.functional as F
from wilds.common.metrics.metric import Metric, ElementwiseMetric, MultiTaskMetric
from wilds.common.metrics.loss import ElementwiseLoss
from wilds.common.utils import avg_over_groups, minimum, maximum, get_counts
@@ -243,12 +243,17 @@ def _accuracy(self, src_boxes,pred_boxes , iou_threshold):
total_pred = len(pred_boxes)
if total_gt > 0 and total_pred > 0:
# Define the matcher and distance matrix based on iou
matcher = Matcher(iou_threshold,iou_threshold,allow_low_quality_matches=False)
match_quality_matrix = box_iou(src_boxes,pred_boxes)
matcher = Matcher(
iou_threshold,
iou_threshold,
allow_low_quality_matches=False)
match_quality_matrix = box_iou(
src_boxes,
pred_boxes)
results = matcher(match_quality_matrix)
true_positive = torch.count_nonzero(results.unique() != -1)
matched_elements = results[results > -1]
#in Matcher, a pred element can be matched only twice
# in Matcher, a pred element can be matched only twice
false_positive = (
torch.count_nonzero(results == -1) +
(len(matched_elements) - len(matched_elements.unique()))
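For context on `_accuracy`: `box_iou` produces a ground-truth × prediction IoU matrix, and torchvision's `Matcher` assigns each prediction the best ground-truth box above the threshold (or -1 if unmatched). A simplified, self-contained sketch on synthetic boxes; it omits the duplicate-match correction applied in the code above:

```python
# Simplified sketch of the Matcher/box_iou logic in _accuracy
# (synthetic boxes in (x1, y1, x2, y2) format; the duplicate-match
# bookkeeping from the real code is omitted).
import torch
from torchvision.models.detection._utils import Matcher
from torchvision.ops.boxes import box_iou

src_boxes = torch.tensor([[0., 0., 10., 10.], [20., 20., 30., 30.]])   # ground truth
pred_boxes = torch.tensor([[1., 1., 10., 10.], [50., 50., 60., 60.]])  # predictions

iou_threshold = 0.5
matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
# Rows: ground-truth boxes; columns: predictions.
results = matcher(box_iou(src_boxes, pred_boxes))  # one entry per prediction

true_positive = torch.count_nonzero(results.unique() != -1)
false_positive = torch.count_nonzero(results == -1)
print(true_positive.item(), false_positive.item())  # 1 TP, 1 FP here
```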
19 changes: 14 additions & 5 deletions wilds/datasets/wilds_dataset.py
@@ -433,11 +433,16 @@ def standard_group_eval(metric, grouper, y_pred, y_true, metadata, aggregate=Tru


class WILDSSubset(WILDSDataset):
def __init__(self, dataset, indices, transform):
def __init__(self, dataset, indices, transform, do_transform_y=False):
"""
This acts like torch.utils.data.Subset, but on WILDSDatasets.
We pass in transform explicitly because it can potentially vary at
training vs. test time, if we're using data augmentation.
This acts like `torch.utils.data.Subset`, but on `WILDSDatasets`.
We pass in `transform` (which is used for data augmentation) explicitly
because it can potentially vary on the training vs. test subsets.
`do_transform_y` (bool): When this is false (the default),
`self.transform` acts only on `x`.
Set this to true if `self.transform` should
operate on `(x,y)` instead of just `x`.
"""
self.dataset = dataset
self.indices = indices
@@ -449,11 +454,15 @@ def __init__(self, dataset, indices, transform):
if hasattr(dataset, attr_name):
setattr(self, attr_name, getattr(dataset, attr_name))
self.transform = transform
self.do_transform_y = do_transform_y

def __getitem__(self, idx):
x, y, metadata = self.dataset[self.indices[idx]]
if self.transform is not None:
x, y = self.transform(x, y)
if self.do_transform_y:
x, y = self.transform(x, y)
else:
x = self.transform(x)
return x, y, metadata

def __len__(self):
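The `__getitem__` change above is the dispatch point for `do_transform_y`. A toy, standalone stand-in (not the real `WILDSSubset`, which wraps a `WILDSDataset`) showing the branching:

```python
# Toy stand-in for WILDSSubset.__getitem__, showing the do_transform_y branch.
class ToySubset:
    def __init__(self, data, transform, do_transform_y=False):
        self.data = data
        self.transform = transform
        self.do_transform_y = do_transform_y

    def __getitem__(self, idx):
        x, y = self.data[idx]
        if self.transform is not None:
            if self.do_transform_y:
                x, y = self.transform(x, y)  # joint (x, y) transform
            else:
                x = self.transform(x)        # default: transform x only
        return x, y

data = [(1, 10), (2, 20)]
print(ToySubset(data, lambda x: -x)[0])                                # (-1, 10)
print(ToySubset(data, lambda x, y: (-x, -y), do_transform_y=True)[1])  # (-2, -20)
```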
2 changes: 1 addition & 1 deletion wilds/version.py
@@ -4,7 +4,7 @@
import logging
from threading import Thread

__version__ = '1.2.1'
__version__ = '1.2.2'

try:
os.environ['OUTDATED_IGNORE'] = '1'
