Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 1, 2024
1 parent 2dcd599 commit 24d5c32
Show file tree
Hide file tree
Showing 43 changed files with 74 additions and 22 deletions.
1 change: 1 addition & 0 deletions src/pl_bolts/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Collection of PyTorchLightning callbacks."""

from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor
from pl_bolts.callbacks.printing import PrintTableMetricsCallback
Expand Down
2 changes: 2 additions & 0 deletions src/pl_bolts/callbacks/data_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def log_histograms(self, batch: Any, group: str = "") -> None:
Otherwise the histograms get labelled with an integer index.
Each label also has the tensors's shape as suffix.
group: Name under which the histograms will be grouped.
"""
if not self._log or (self._train_batch_idx + 1) % self._log_every_n_steps != 0: # type: ignore[operator]
return
Expand Down Expand Up @@ -220,6 +221,7 @@ def __init__(self, log_every_n_steps: int = None) -> None:
# log histogram of training data passed to `LightningModule.training_step`
trainer = Trainer(callbacks=[TrainingDataMonitor()])
"""
super().__init__(log_every_n_steps=log_every_n_steps)

Expand Down
6 changes: 4 additions & 2 deletions src/pl_bolts/callbacks/verification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def _model_forward(self, input_array: Any) -> Any:
Returns:
The output of the model.
"""
if isinstance(input_array, tuple):
return self.model(*input_array)
Expand Down Expand Up @@ -105,15 +106,16 @@ def __init__(self, warn: bool = True, error: bool = False) -> None:
self._raise_error = error

def message(self, *args: Any, **kwargs: Any) -> str:
"""The message to be printed when the model does not pass the verification. If the message for warning and
error differ, override the :meth:`warning_message` and :meth:`error_message` methods directly.
"""The message to be printed when the model does not pass the verification. If the message for warning and error
differ, override the :meth:`warning_message` and :meth:`error_message` methods directly.
Arguments:
*args: Any positional arguments that are needed to construct the message.
**kwargs: Any keyword arguments that are needed to construct the message.
Returns:
The message as a string.
"""

def warning_message(self, *args: Any, **kwargs: Any) -> str:
Expand Down
6 changes: 4 additions & 2 deletions src/pl_bolts/callbacks/verification/batch_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class BatchGradientVerificationCallback(VerificationCallbackBase):
"""The callback version of the :class:`BatchGradientVerification` test.
Verification is performed right before training begins.
"""

def __init__(
Expand Down Expand Up @@ -211,12 +212,13 @@ def collect_batches(tensor: Tensor) -> Tensor:
@under_review()
@contextmanager
def selective_eval(model: nn.Module, layer_types: Iterable[Type[nn.Module]]) -> None:
"""A context manager that sets all requested types of layers to eval mode. This method uses an ``isinstance``
check, so all subclasses are also affected.
"""A context manager that sets all requested types of layers to eval mode. This method uses an ``isinstance`` check,
so all subclasses are also affected.
Args:
model: A model which has layers that need to be set to eval mode.
layer_types: The list of class objects for which all layers of that type will be set to eval mode.
"""
to_revert = []
try:
Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/callbacks/vision/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class TensorboardGenerativeModelImageSampler(Callback):
from pl_bolts.callbacks import TensorboardGenerativeModelImageSampler
trainer = Trainer(callbacks=[TensorboardGenerativeModelImageSampler()])
"""

def __init__(
Expand Down
5 changes: 3 additions & 2 deletions src/pl_bolts/callbacks/vision/sr_image_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

@under_review()
class SRImageLoggerCallback(Callback):
"""Logs low-res, generated high-res, and ground truth high-res images to TensorBoard Your model must implement
the ``forward`` function for generation.
"""Logs low-res, generated high-res, and ground truth high-res images to TensorBoard Your model must implement the
``forward`` function for generation.
Requirements::
Expand All @@ -30,6 +30,7 @@ class SRImageLoggerCallback(Callback):
from pl_bolts.callbacks import SRImageLoggerCallback
trainer = Trainer(callbacks=[SRImageLoggerCallback()])
"""

def __init__(self, log_interval: int = 1000, scale_factor: int = 4, num_samples: int = 5) -> None:
Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/datamodules/experience_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Datamodules for RL models that rely on experiences generated during training Based on implementations found
here: https://github.com/Shmuma/ptan/blob/master/ptan/experience.py."""

from abc import ABC
from collections import deque, namedtuple
from typing import Callable, Iterator, List, Tuple
Expand Down
4 changes: 2 additions & 2 deletions src/pl_bolts/datamodules/imagenet_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,12 @@ def train_dataloader(self) -> DataLoader:
return loader

def val_dataloader(self) -> DataLoader:
"""Uses the part of the train split of imagenet2012 that was not used for training via
`num_imgs_per_val_class`
"""Uses the part of the train split of imagenet2012 that was not used for training via `num_imgs_per_val_class`
Args:
batch_size: the batch size
transforms: the transforms
"""
transforms = self.val_transform() if self.val_transforms is None else self.val_transforms

Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/datamodules/stl10_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def train_dataloader_mixed(self) -> DataLoader:
batch_size: the batch size
transforms: a sequence of transforms
"""
transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms

Expand Down
2 changes: 2 additions & 0 deletions src/pl_bolts/datamodules/vocdetection_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def train_dataloader(self, image_transforms: Optional[Callable] = None) -> DataL
Args:
image_transforms: custom image-only transforms
"""
transforms = [
_prepare_voc_instance,
Expand All @@ -181,6 +182,7 @@ def val_dataloader(self, image_transforms: Optional[Callable] = None) -> DataLoa
Args:
image_transforms: custom image-only transforms
"""
transforms = [
_prepare_voc_instance,
Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/datasets/kitti_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def encode_segmap(self, mask):
It also sets all of the valid pixels to the appropriate value between 0 and `len(valid_labels)` (the number of
valid classes), so it can be used properly by the loss function when comparing with the output.
"""
for voidc in self.void_labels:
mask[mask == voidc] = self.ignore_index
Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/datasets/sr_dataset_mixin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Adapted from: https://github.com/https-deeplearning-ai/GANs-Public."""

from typing import Any, Tuple

import torch
Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def to_tensor(arrays: TArrays) -> torch.Tensor:
Returns:
Tensor of the integers
"""
return torch.tensor(arrays)

Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Collection of PyTorchLightning models."""

from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import AE
from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE
from pl_bolts.models.mnist_module import LitMNIST
Expand Down
6 changes: 4 additions & 2 deletions src/pl_bolts/models/detection/yolo/darknet_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def read(tensor: Tensor) -> int:
"""Reads the contents of ``tensor`` from the current position of ``weight_file``.
Returns the number of elements read. If there's no more data in ``weight_file``, returns 0.
"""
np_array = np.fromfile(weight_file, count=tensor.numel(), dtype=np.float32)
num_elements = np_array.size
Expand Down Expand Up @@ -275,8 +276,8 @@ def convert(key: str, value: str) -> Union[str, int, float, List[Union[str, int,


def _create_layer(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT:
"""Calls one of the ``_create_<layertype>(config, num_inputs)`` functions to create a PyTorch module from the
layer config.
"""Calls one of the ``_create_<layertype>(config, num_inputs)`` functions to create a PyTorch module from the layer
config.
Args:
config: Dictionary of configuration options for this layer.
Expand All @@ -285,6 +286,7 @@ def _create_layer(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREAT
Returns:
module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in
its output.
"""
create_func: Dict[str, Callable[..., CREATE_LAYER_OUTPUT]] = {
"convolutional": _create_convolutional,
Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/models/detection/yolo/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def _target_labels_to_probs(
Returns:
An ``[M, C]`` matrix of target class probabilities.
"""
if targets.ndim == 1:
# The data may contain a different number of classes than what the model predicts. In case a label is
Expand Down
12 changes: 8 additions & 4 deletions src/pl_bolts/models/detection/yolo/torch_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def run_detection(
detections: A list where a tensor containing the detections will be appended to.
losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given.
hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is given.
"""
output, preds = detection_layer(layer_input, image_size)
detections.append(output)
Expand Down Expand Up @@ -69,6 +70,7 @@ def run_detection_with_aux_head(
detections: A list where a tensor containing the detections will be appended to.
losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given.
hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is given.
"""
output, preds = detection_layer(layer_input, image_size)
detections.append(output)
Expand Down Expand Up @@ -1132,8 +1134,8 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU


class YOLOV5Network(nn.Module):
"""The YOLOv5 network architecture. Different variants (n/s/m/l/x) can be achieved by adjusting the ``depth``
and ``width`` parameters.
"""The YOLOv5 network architecture. Different variants (n/s/m/l/x) can be achieved by adjusting the ``depth`` and
``width`` parameters.
Args:
num_classes: Number of different classes that this model predicts.
Expand Down Expand Up @@ -1176,6 +1178,7 @@ class YOLOV5Network(nn.Module):
class_loss_multiplier: Classification loss will be scaled by this value.
xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps
to produce coordinate values close to one.
"""

def __init__(
Expand Down Expand Up @@ -1613,8 +1616,8 @@ def forward(self, x: Tensor) -> Tensor:


class YOLOXNetwork(nn.Module):
"""The YOLOX network architecture. Different variants (nano/tiny/s/m/l/x) can be achieved by adjusting the
``depth`` and ``width`` parameters.
"""The YOLOX network architecture. Different variants (nano/tiny/s/m/l/x) can be achieved by adjusting the ``depth``
and ``width`` parameters.
Args:
num_classes: Number of different classes that this model predicts.
Expand Down Expand Up @@ -1657,6 +1660,7 @@ class YOLOXNetwork(nn.Module):
class_loss_multiplier: Classification loss will be scaled by this value.
xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps
to produce coordinate values close to one.
"""

def __init__(
Expand Down
5 changes: 3 additions & 2 deletions src/pl_bolts/models/detection/yolo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def aligned_iou(wh1: Tensor, wh2: Tensor) -> Tensor:


def iou_below(pred_boxes: Tensor, target_boxes: Tensor, threshold: float) -> Tensor:
"""Creates a binary mask whose value will be ``True``, unless the predicted box overlaps any target
significantly (IoU greater than ``threshold``).
"""Creates a binary mask whose value will be ``True``, unless the predicted box overlaps any target significantly
(IoU greater than ``threshold``).
Args:
pred_boxes: The predicted corner coordinates. Tensor of size ``[height, width, boxes_per_cell, 4]``.
Expand All @@ -112,6 +112,7 @@ def iou_below(pred_boxes: Tensor, target_boxes: Tensor, threshold: float) -> Ten
Returns:
A boolean tensor sized ``[height, width, boxes_per_cell]``, with ``False`` where the predicted box overlaps a
target significantly and ``True`` elsewhere.
"""
shape = pred_boxes.shape[:-1]
pred_boxes = pred_boxes.view(-1, 4)
Expand Down
12 changes: 8 additions & 4 deletions src/pl_bolts/models/detection/yolo/yolo_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ def __init__(
def forward(
self, images: Union[Tensor, IMAGES], targets: Optional[TARGETS] = None
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets
are provided, computes the losses from the detection layers.
"""Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets are
provided, computes the losses from the detection layers.
Detections are concatenated from the detection layers. Each detection layer will produce a number of detections
that depends on the size of the feature map and the number of anchors per feature map cell.
Expand All @@ -161,6 +161,7 @@ def forward(
provided, a dictionary of losses. Detections are shaped ``[batch_size, anchors, classes + 5]``, where
``anchors`` is the feature map size (width * height) times the number of anchors per cell. The predicted box
coordinates are in `(x1, y1, x2, y2)` format and scaled to the input image size.
"""
self.validate_batch(images, targets)
images_tensor = images if isinstance(images, Tensor) else torch.stack(images)
Expand All @@ -185,6 +186,7 @@ def configure_optimizers(self) -> Tuple[List[optim.Optimizer], List[LRScheduler]
If weight decay is specified, it will be applied only to convolutional layer weights, as they contain much more
parameters than the biases and batch normalization parameters. Regularizing all parameters could lead to
underfitting.
"""
if ("weight_decay" in self.optimizer_params) and (self.optimizer_params["weight_decay"] != 0):
defaults = copy(self.optimizer_params)
Expand Down Expand Up @@ -574,12 +576,13 @@ def __init__(


class ResizedVOCDetectionDataModule(VOCDetectionDataModule):
"""A subclass of ``VOCDetectionDataModule`` that resizes the images to a specific size. YOLO expectes the image
size to be divisible by the ratio in which the network downsamples the image.
"""A subclass of ``VOCDetectionDataModule`` that resizes the images to a specific size. YOLO expectes the image size
to be divisible by the ratio in which the network downsamples the image.
Args:
width: Resize images to this width.
height: Resize images to this height.
"""

def __init__(self, width: int = 608, height: int = 608, **kwargs: Any):
Expand Down Expand Up @@ -609,6 +612,7 @@ def _resize(self, image: Tensor, target: TARGET) -> Tuple[Tensor, TARGET]:
Returns:
Resized image tensor.
"""
device = target["boxes"].device
height, width = image.shape[-2:]
Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/models/gans/srgan/components.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Adapted from: https://github.com/https-deeplearning-ai/GANs-Public."""

import torch
import torch.nn as nn

Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/models/gans/srgan/srgan_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Adapted from: https://github.com/https-deeplearning-ai/GANs-Public."""

from argparse import ArgumentParser
from pathlib import Path
from typing import Any, List, Optional, Tuple
Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/models/gans/srgan/srresnet_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Adapted from: https://github.com/https-deeplearning-ai/GANs-Public."""

from argparse import ArgumentParser
from typing import Any, Tuple

Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/models/rl/advantage_actor_critic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Advantage Actor Critic (A2C)"""

from argparse import ArgumentParser
from collections import OrderedDict
from typing import Any, Iterator, List, Tuple
Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/models/rl/common/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
https://github.com/Shmuma/ptan/blob/master/ptan/agent.py.
"""

from abc import ABC
from typing import List

Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/models/rl/common/distributions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Distributions used in some continuous RL algorithms."""

import torch

from pl_bolts.utils.stability import under_review
Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/models/rl/common/gym_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Set of wrapper functions for gym environments taken from
https://github.com/Shmuma/ptan/blob/master/ptan/common/wrappers.py."""

import collections

import numpy as np
Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/models/rl/common/networks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Series of networks used Based on implementations found here:"""

import math
from typing import Tuple

Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/models/rl/double_dqn_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Double DQN."""

import argparse
from collections import OrderedDict
from typing import Tuple
Expand Down
1 change: 1 addition & 0 deletions src/pl_bolts/models/rl/dqn_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Deep Q Network."""

import argparse
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
Expand Down
Loading

0 comments on commit 24d5c32

Please sign in to comment.