From 795d92482b1cfb99f4c86d79064b7066cf71d4e6 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 6 May 2024 20:36:29 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/pl_bolts/callbacks/__init__.py | 1 +
 src/pl_bolts/callbacks/data_monitor.py | 4 +++-
 src/pl_bolts/callbacks/verification/base.py | 6 ++++--
 .../callbacks/verification/batch_gradient.py | 6 ++++--
 src/pl_bolts/callbacks/vision/image_generation.py | 1 +
 src/pl_bolts/callbacks/vision/sr_image_logger.py | 5 +++--
 src/pl_bolts/datamodules/experience_source.py | 1 +
 src/pl_bolts/datamodules/imagenet_datamodule.py | 4 ++--
 src/pl_bolts/datamodules/stl10_datamodule.py | 1 +
 src/pl_bolts/datamodules/vocdetection_datamodule.py | 2 ++
 src/pl_bolts/datasets/kitti_dataset.py | 1 +
 src/pl_bolts/datasets/sr_dataset_mixin.py | 1 +
 src/pl_bolts/datasets/utils.py | 1 +
 src/pl_bolts/models/__init__.py | 1 +
 .../models/detection/yolo/darknet_network.py | 6 ++++--
 src/pl_bolts/models/detection/yolo/loss.py | 1 +
 src/pl_bolts/models/detection/yolo/torch_networks.py | 12 ++++++++----
 src/pl_bolts/models/detection/yolo/utils.py | 5 +++--
 src/pl_bolts/models/detection/yolo/yolo_module.py | 12 ++++++++----
 src/pl_bolts/models/gans/srgan/components.py | 1 +
 src/pl_bolts/models/gans/srgan/srgan_module.py | 1 +
 src/pl_bolts/models/gans/srgan/srresnet_module.py | 1 +
 .../models/rl/advantage_actor_critic_model.py | 1 +
 src/pl_bolts/models/rl/common/agents.py | 1 +
 src/pl_bolts/models/rl/common/distributions.py | 1 +
 src/pl_bolts/models/rl/common/gym_wrappers.py | 1 +
 src/pl_bolts/models/rl/common/networks.py | 1 +
 src/pl_bolts/models/rl/double_dqn_model.py | 1 +
 src/pl_bolts/models/rl/dqn_model.py | 1 +
 src/pl_bolts/models/rl/dueling_dqn_model.py | 1 +
 src/pl_bolts/models/rl/noisy_dqn_model.py | 1 +
 src/pl_bolts/models/rl/per_dqn_model.py | 1 +
 src/pl_bolts/models/rl/ppo_model.py | 3 +--
 src/pl_bolts/models/rl/sac_model.py | 1 +
 src/pl_bolts/models/self_supervised/__init__.py | 1 +
 .../models/self_supervised/cpc/cpc_module.py | 1 +
 .../models/self_supervised/moco/moco_module.py | 2 ++
 src/pl_bolts/models/self_supervised/moco/utils.py | 1 +
 .../models/self_supervised/swav/swav_module.py | 1 +
 .../models/self_supervised/swav/swav_resnet.py | 1 +
 src/pl_bolts/models/vision/pixel_cnn.py | 1 +
 src/pl_bolts/optimizers/lars.py | 1 +
 tests/models/rl/unit/test_agents.py | 1 +
 43 files changed, 75 insertions(+), 23 deletions(-)

diff --git a/src/pl_bolts/callbacks/__init__.py b/src/pl_bolts/callbacks/__init__.py
index 2225372f48..ce72e9dd10 100644
--- a/src/pl_bolts/callbacks/__init__.py
+++ b/src/pl_bolts/callbacks/__init__.py
@@ -1,4 +1,5 @@
 """Collection of PyTorchLightning callbacks."""
+
 from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
 from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor
 from pl_bolts.callbacks.printing import PrintTableMetricsCallback
diff --git a/src/pl_bolts/callbacks/data_monitor.py b/src/pl_bolts/callbacks/data_monitor.py
index 7a39a1e709..0058e07c28 100644
--- a/src/pl_bolts/callbacks/data_monitor.py
+++ b/src/pl_bolts/callbacks/data_monitor.py
@@ -73,6 +73,7 @@ def log_histograms(self, batch: Any, group: str = "") -> None:
                 Otherwise the histograms get labelled with an integer index.
                 Each label also has the tensors's shape as suffix.
             group: Name under which the histograms will be grouped.
+
         """
         if not self._log or (self._train_batch_idx + 1) % self._log_every_n_steps != 0:  # type: ignore[operator]
             return
@@ -112,7 +113,7 @@ def _is_logger_available(self, logger: Logger) -> bool:
         if not isinstance(logger, self.supported_loggers):
             rank_zero_warn(
                 f"{self.__class__.__name__} does not support logging with {logger.__class__.__name__}."
-                f" Supported loggers are: {', '.join((str(x.__name__) for x in self.supported_loggers))}"
+                f" Supported loggers are: {', '.join(str(x.__name__) for x in self.supported_loggers)}"
             )
             available = False
         return available
@@ -220,6 +221,7 @@ def __init__(self, log_every_n_steps: int = None) -> None:
             # log histogram of training data passed to `LightningModule.training_step`
             trainer = Trainer(callbacks=[TrainingDataMonitor()])
+
         """
         super().__init__(log_every_n_steps=log_every_n_steps)
diff --git a/src/pl_bolts/callbacks/verification/base.py b/src/pl_bolts/callbacks/verification/base.py
index 49e2e3593a..1f6d59c0c8 100644
--- a/src/pl_bolts/callbacks/verification/base.py
+++ b/src/pl_bolts/callbacks/verification/base.py
@@ -77,6 +77,7 @@ def _model_forward(self, input_array: Any) -> Any:
 
         Returns:
             The output of the model.
+
         """
         if isinstance(input_array, tuple):
             return self.model(*input_array)
@@ -105,8 +106,8 @@ def __init__(self, warn: bool = True, error: bool = False) -> None:
         self._raise_error = error
 
     def message(self, *args: Any, **kwargs: Any) -> str:
-        """The message to be printed when the model does not pass the verification. If the message for warning and
-        error differ, override the :meth:`warning_message` and :meth:`error_message` methods directly.
+        """The message to be printed when the model does not pass the verification. If the message for warning and error
+        differ, override the :meth:`warning_message` and :meth:`error_message` methods directly.
 
         Arguments:
             *args: Any positional arguments that are needed to construct the message.
@@ -114,6 +115,7 @@ def message(self, *args: Any, **kwargs: Any) -> str:
 
         Returns:
             The message as a string.
+
         """
 
     def warning_message(self, *args: Any, **kwargs: Any) -> str:
diff --git a/src/pl_bolts/callbacks/verification/batch_gradient.py b/src/pl_bolts/callbacks/verification/batch_gradient.py
index 834184c152..107c647d9f 100644
--- a/src/pl_bolts/callbacks/verification/batch_gradient.py
+++ b/src/pl_bolts/callbacks/verification/batch_gradient.py
@@ -91,6 +91,7 @@ class BatchGradientVerificationCallback(VerificationCallbackBase):
     """The callback version of the :class:`BatchGradientVerification` test.
 
     Verification is performed right before training begins.
+
     """
 
     def __init__(
@@ -211,12 +212,13 @@ def collect_batches(tensor: Tensor) -> Tensor:
 @under_review()
 @contextmanager
 def selective_eval(model: nn.Module, layer_types: Iterable[Type[nn.Module]]) -> None:
-    """A context manager that sets all requested types of layers to eval mode. This method uses an ``isinstance``
-    check, so all subclasses are also affected.
+    """A context manager that sets all requested types of layers to eval mode. This method uses an ``isinstance`` check,
+    so all subclasses are also affected.
 
     Args:
         model: A model which has layers that need to be set to eval mode.
         layer_types: The list of class objects for which all layers of that type will be set to eval mode.
+
     """
     to_revert = []
     try:
diff --git a/src/pl_bolts/callbacks/vision/image_generation.py b/src/pl_bolts/callbacks/vision/image_generation.py
index a30e78972b..0e860292de 100644
--- a/src/pl_bolts/callbacks/vision/image_generation.py
+++ b/src/pl_bolts/callbacks/vision/image_generation.py
@@ -31,6 +31,7 @@ class TensorboardGenerativeModelImageSampler(Callback):
         from pl_bolts.callbacks import TensorboardGenerativeModelImageSampler
 
         trainer = Trainer(callbacks=[TensorboardGenerativeModelImageSampler()])
+
     """
 
     def __init__(
diff --git a/src/pl_bolts/callbacks/vision/sr_image_logger.py b/src/pl_bolts/callbacks/vision/sr_image_logger.py
index f27bd294e2..4a96330013 100644
--- a/src/pl_bolts/callbacks/vision/sr_image_logger.py
+++ b/src/pl_bolts/callbacks/vision/sr_image_logger.py
@@ -17,8 +17,8 @@
 
 @under_review()
 class SRImageLoggerCallback(Callback):
-    """Logs low-res, generated high-res, and ground truth high-res images to TensorBoard Your model must implement
-    the ``forward`` function for generation.
+    """Logs low-res, generated high-res, and ground truth high-res images to TensorBoard Your model must implement the
+    ``forward`` function for generation.
 
     Requirements::
@@ -30,6 +30,7 @@ class SRImageLoggerCallback(Callback):
         from pl_bolts.callbacks import SRImageLoggerCallback
 
         trainer = Trainer(callbacks=[SRImageLoggerCallback()])
+
     """
 
     def __init__(self, log_interval: int = 1000, scale_factor: int = 4, num_samples: int = 5) -> None:
diff --git a/src/pl_bolts/datamodules/experience_source.py b/src/pl_bolts/datamodules/experience_source.py
index 2a0d4467e4..7967b5ca46 100644
--- a/src/pl_bolts/datamodules/experience_source.py
+++ b/src/pl_bolts/datamodules/experience_source.py
@@ -1,5 +1,6 @@
 """Datamodules for RL models that rely on experiences generated during training Based on implementations found here:
 https://github.com/Shmuma/ptan/blob/master/ptan/experience.py."""
+
 from abc import ABC
 from collections import deque, namedtuple
 from typing import Callable, Iterator, List, Tuple
diff --git a/src/pl_bolts/datamodules/imagenet_datamodule.py b/src/pl_bolts/datamodules/imagenet_datamodule.py
index 90f2bb641d..3fec1bab68 100644
--- a/src/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/src/pl_bolts/datamodules/imagenet_datamodule.py
@@ -167,12 +167,12 @@ def train_dataloader(self) -> DataLoader:
         return loader
 
     def val_dataloader(self) -> DataLoader:
-        """Uses the part of the train split of imagenet2012 that was not used for training via
-        `num_imgs_per_val_class`
+        """Uses the part of the train split of imagenet2012 that was not used for training via `num_imgs_per_val_class`
 
         Args:
             batch_size: the batch size
             transforms: the transforms
+
         """
         transforms = self.val_transform() if self.val_transforms is None else self.val_transforms
diff --git a/src/pl_bolts/datamodules/stl10_datamodule.py b/src/pl_bolts/datamodules/stl10_datamodule.py
index 158baee0bc..bff8113256 100644
--- a/src/pl_bolts/datamodules/stl10_datamodule.py
+++ b/src/pl_bolts/datamodules/stl10_datamodule.py
@@ -139,6 +139,7 @@ def train_dataloader_mixed(self) -> DataLoader:
             batch_size: the batch size
             transforms: a sequence of transforms
+
         """
         transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms
diff --git a/src/pl_bolts/datamodules/vocdetection_datamodule.py b/src/pl_bolts/datamodules/vocdetection_datamodule.py
index de8a84b0f3..d6435bcc91 100644
--- a/src/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/src/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -166,6 +166,7 @@ def train_dataloader(self, image_transforms: Optional[Callable] = None) -> DataL
 
         Args:
             image_transforms: custom image-only transforms
+
         """
         transforms = [
             _prepare_voc_instance,
@@ -181,6 +182,7 @@ def val_dataloader(self, image_transforms: Optional[Callable] = None) -> DataLoa
 
         Args:
             image_transforms: custom image-only transforms
+
         """
         transforms = [
             _prepare_voc_instance,
diff --git a/src/pl_bolts/datasets/kitti_dataset.py b/src/pl_bolts/datasets/kitti_dataset.py
index 0e6674224d..92b08ba9c3 100644
--- a/src/pl_bolts/datasets/kitti_dataset.py
+++ b/src/pl_bolts/datasets/kitti_dataset.py
@@ -86,6 +86,7 @@ def encode_segmap(self, mask):
         It also sets all of the valid pixels to the appropriate value between 0 and `len(valid_labels)` (the number of
         valid classes), so it can be used properly by the loss function when comparing with the output.
+
         """
         for voidc in self.void_labels:
             mask[mask == voidc] = self.ignore_index
diff --git a/src/pl_bolts/datasets/sr_dataset_mixin.py b/src/pl_bolts/datasets/sr_dataset_mixin.py
index cdeddce054..07570a0e69 100644
--- a/src/pl_bolts/datasets/sr_dataset_mixin.py
+++ b/src/pl_bolts/datasets/sr_dataset_mixin.py
@@ -1,4 +1,5 @@
 """Adapted from: https://github.com/https-deeplearning-ai/GANs-Public."""
+
 from typing import Any, Tuple
 
 import torch
diff --git a/src/pl_bolts/datasets/utils.py b/src/pl_bolts/datasets/utils.py
index e3b085fbc1..52a25d46ea 100644
--- a/src/pl_bolts/datasets/utils.py
+++ b/src/pl_bolts/datasets/utils.py
@@ -60,6 +60,7 @@ def to_tensor(arrays: TArrays) -> torch.Tensor:
 
     Returns:
         Tensor of the integers
+
     """
     return torch.tensor(arrays)
diff --git a/src/pl_bolts/models/__init__.py b/src/pl_bolts/models/__init__.py
index 66f237728d..7ddbd599bb 100644
--- a/src/pl_bolts/models/__init__.py
+++ b/src/pl_bolts/models/__init__.py
@@ -1,4 +1,5 @@
 """Collection of PyTorchLightning models."""
+
 from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import AE
 from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE
 from pl_bolts.models.mnist_module import LitMNIST
diff --git a/src/pl_bolts/models/detection/yolo/darknet_network.py b/src/pl_bolts/models/detection/yolo/darknet_network.py
index 7a38a0f5d2..78f87d62d9 100644
--- a/src/pl_bolts/models/detection/yolo/darknet_network.py
+++ b/src/pl_bolts/models/detection/yolo/darknet_network.py
@@ -145,6 +145,7 @@ def read(tensor: Tensor) -> int:
             """Reads the contents of ``tensor`` from the current position of ``weight_file``.
 
             Returns the number of elements read. If there's no more data in ``weight_file``, returns 0.
+
             """
             np_array = np.fromfile(weight_file, count=tensor.numel(), dtype=np.float32)
             num_elements = np_array.size
@@ -275,8 +276,8 @@ def convert(key: str, value: str) -> Union[str, int, float, List[Union[str, int,
 
 
 def _create_layer(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT:
-    """Calls one of the ``_create_(config, num_inputs)`` functions to create a PyTorch module from the
-    layer config.
+    """Calls one of the ``_create_(config, num_inputs)`` functions to create a PyTorch module from the layer
+    config.
 
     Args:
         config: Dictionary of configuration options for this layer.
@@ -285,6 +286,7 @@ def _create_layer(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREAT
 
     Returns:
         module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in
        its output.
+
     """
     create_func: Dict[str, Callable[..., CREATE_LAYER_OUTPUT]] = {
         "convolutional": _create_convolutional,
diff --git a/src/pl_bolts/models/detection/yolo/loss.py b/src/pl_bolts/models/detection/yolo/loss.py
index 44ac5b0f11..6bcaadd0de 100644
--- a/src/pl_bolts/models/detection/yolo/loss.py
+++ b/src/pl_bolts/models/detection/yolo/loss.py
@@ -205,6 +205,7 @@ def _target_labels_to_probs(
 
     Returns:
         An ``[M, C]`` matrix of target class probabilities.
+
     """
     if targets.ndim == 1:
         # The data may contain a different number of classes than what the model predicts. In case a label is
diff --git a/src/pl_bolts/models/detection/yolo/torch_networks.py b/src/pl_bolts/models/detection/yolo/torch_networks.py
index ee5358ac7f..9e59eec796 100644
--- a/src/pl_bolts/models/detection/yolo/torch_networks.py
+++ b/src/pl_bolts/models/detection/yolo/torch_networks.py
@@ -31,6 +31,7 @@ def run_detection(
         detections: A list where a tensor containing the detections will be appended to.
         losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given.
         hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is given.
+
     """
     output, preds = detection_layer(layer_input, image_size)
     detections.append(output)
@@ -69,6 +70,7 @@ def run_detection_with_aux_head(
         detections: A list where a tensor containing the detections will be appended to.
         losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given.
         hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is given.
+
     """
     output, preds = detection_layer(layer_input, image_size)
     detections.append(output)
@@ -1132,8 +1134,8 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU
 
 
 class YOLOV5Network(nn.Module):
-    """The YOLOv5 network architecture. Different variants (n/s/m/l/x) can be achieved by adjusting the ``depth``
-    and ``width`` parameters.
+    """The YOLOv5 network architecture. Different variants (n/s/m/l/x) can be achieved by adjusting the ``depth`` and
+    ``width`` parameters.
 
     Args:
         num_classes: Number of different classes that this model predicts.
@@ -1176,6 +1178,7 @@ class YOLOV5Network(nn.Module):
         class_loss_multiplier: Classification loss will be scaled by this value.
         xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps
             to produce coordinate values close to one.
+
     """
 
     def __init__(
@@ -1613,8 +1616,8 @@ def forward(self, x: Tensor) -> Tensor:
 
 
 class YOLOXNetwork(nn.Module):
-    """The YOLOX network architecture. Different variants (nano/tiny/s/m/l/x) can be achieved by adjusting the
-    ``depth`` and ``width`` parameters.
+    """The YOLOX network architecture. Different variants (nano/tiny/s/m/l/x) can be achieved by adjusting the ``depth``
+    and ``width`` parameters.
 
     Args:
         num_classes: Number of different classes that this model predicts.
@@ -1657,6 +1660,7 @@ class YOLOXNetwork(nn.Module):
         class_loss_multiplier: Classification loss will be scaled by this value.
         xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps
             to produce coordinate values close to one.
+
     """
 
     def __init__(
diff --git a/src/pl_bolts/models/detection/yolo/utils.py b/src/pl_bolts/models/detection/yolo/utils.py
index d981fadceb..66996fc4d8 100644
--- a/src/pl_bolts/models/detection/yolo/utils.py
+++ b/src/pl_bolts/models/detection/yolo/utils.py
@@ -102,8 +102,8 @@ def aligned_iou(wh1: Tensor, wh2: Tensor) -> Tensor:
 
 
 def iou_below(pred_boxes: Tensor, target_boxes: Tensor, threshold: float) -> Tensor:
-    """Creates a binary mask whose value will be ``True``, unless the predicted box overlaps any target
-    significantly (IoU greater than ``threshold``).
+    """Creates a binary mask whose value will be ``True``, unless the predicted box overlaps any target significantly
+    (IoU greater than ``threshold``).
 
     Args:
         pred_boxes: The predicted corner coordinates. Tensor of size ``[height, width, boxes_per_cell, 4]``.
@@ -112,6 +112,7 @@ def iou_below(pred_boxes: Tensor, target_boxes: Tensor, threshold: float) -> Ten
 
     Returns:
         A boolean tensor sized ``[height, width, boxes_per_cell]``, with ``False`` where the predicted box overlaps a
         target significantly and ``True`` elsewhere.
+
     """
     shape = pred_boxes.shape[:-1]
     pred_boxes = pred_boxes.view(-1, 4)
diff --git a/src/pl_bolts/models/detection/yolo/yolo_module.py b/src/pl_bolts/models/detection/yolo/yolo_module.py
index 429ed087b7..558cb496b7 100644
--- a/src/pl_bolts/models/detection/yolo/yolo_module.py
+++ b/src/pl_bolts/models/detection/yolo/yolo_module.py
@@ -144,8 +144,8 @@ def __init__(
     def forward(
         self, images: Union[Tensor, IMAGES], targets: Optional[TARGETS] = None
     ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
-        """Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets
-        are provided, computes the losses from the detection layers.
+        """Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets are
+        provided, computes the losses from the detection layers.
 
         Detections are concatenated from the detection layers. Each detection layer will produce a number of
         detections that depends on the size of the feature map and the number of anchors per feature map cell.
@@ -161,6 +161,7 @@ def forward(
             provided, a dictionary of losses. Detections are shaped ``[batch_size, anchors, classes + 5]``, where
             ``anchors`` is the feature map size (width * height) times the number of anchors per cell. The predicted
             box coordinates are in `(x1, y1, x2, y2)` format and scaled to the input image size.
+
         """
         self.validate_batch(images, targets)
         images_tensor = images if isinstance(images, Tensor) else torch.stack(images)
@@ -185,6 +186,7 @@ def configure_optimizers(self) -> Tuple[List[optim.Optimizer], List[LRScheduler]
         If weight decay is specified, it will be applied only to convolutional layer weights, as they contain much
         more parameters than the biases and batch normalization parameters. Regularizing all parameters could lead to
         underfitting.
+
         """
         if ("weight_decay" in self.optimizer_params) and (self.optimizer_params["weight_decay"] != 0):
             defaults = copy(self.optimizer_params)
@@ -574,12 +576,13 @@ def __init__(
 
 
 class ResizedVOCDetectionDataModule(VOCDetectionDataModule):
-    """A subclass of ``VOCDetectionDataModule`` that resizes the images to a specific size. YOLO expectes the image
-    size to be divisible by the ratio in which the network downsamples the image.
+    """A subclass of ``VOCDetectionDataModule`` that resizes the images to a specific size. YOLO expectes the image size
+    to be divisible by the ratio in which the network downsamples the image.
 
     Args:
         width: Resize images to this width.
        height: Resize images to this height.
+
     """
 
     def __init__(self, width: int = 608, height: int = 608, **kwargs: Any):
@@ -609,6 +612,7 @@ def _resize(self, image: Tensor, target: TARGET) -> Tuple[Tensor, TARGET]:
 
         Returns:
             Resized image tensor.
+
         """
         device = target["boxes"].device
         height, width = image.shape[-2:]
diff --git a/src/pl_bolts/models/gans/srgan/components.py b/src/pl_bolts/models/gans/srgan/components.py
index 99ad9d2e6a..63a531c46a 100644
--- a/src/pl_bolts/models/gans/srgan/components.py
+++ b/src/pl_bolts/models/gans/srgan/components.py
@@ -1,4 +1,5 @@
 """Adapted from: https://github.com/https-deeplearning-ai/GANs-Public."""
+
 import torch
 import torch.nn as nn
diff --git a/src/pl_bolts/models/gans/srgan/srgan_module.py b/src/pl_bolts/models/gans/srgan/srgan_module.py
index ef11f10dc2..2799aa5c29 100644
--- a/src/pl_bolts/models/gans/srgan/srgan_module.py
+++ b/src/pl_bolts/models/gans/srgan/srgan_module.py
@@ -1,4 +1,5 @@
 """Adapted from: https://github.com/https-deeplearning-ai/GANs-Public."""
+
 from argparse import ArgumentParser
 from pathlib import Path
 from typing import Any, List, Optional, Tuple
diff --git a/src/pl_bolts/models/gans/srgan/srresnet_module.py b/src/pl_bolts/models/gans/srgan/srresnet_module.py
index fc6ba2498b..1545e63391 100644
--- a/src/pl_bolts/models/gans/srgan/srresnet_module.py
+++ b/src/pl_bolts/models/gans/srgan/srresnet_module.py
@@ -1,4 +1,5 @@
 """Adapted from: https://github.com/https-deeplearning-ai/GANs-Public."""
+
 from argparse import ArgumentParser
 from typing import Any, Tuple
diff --git a/src/pl_bolts/models/rl/advantage_actor_critic_model.py b/src/pl_bolts/models/rl/advantage_actor_critic_model.py
index e4863e32fe..30c73e446d 100644
--- a/src/pl_bolts/models/rl/advantage_actor_critic_model.py
+++ b/src/pl_bolts/models/rl/advantage_actor_critic_model.py
@@ -1,4 +1,5 @@
 """Advantage Actor Critic (A2C)"""
+
 from argparse import ArgumentParser
 from collections import OrderedDict
 from typing import Any, Iterator, List, Tuple
diff --git a/src/pl_bolts/models/rl/common/agents.py b/src/pl_bolts/models/rl/common/agents.py
index 116b0b89dd..ad3746bbaf 100644
--- a/src/pl_bolts/models/rl/common/agents.py
+++ b/src/pl_bolts/models/rl/common/agents.py
@@ -2,6 +2,7 @@
 
 https://github.com/Shmuma/ptan/blob/master/ptan/agent.py.
 """
+
 from abc import ABC
 from typing import List
diff --git a/src/pl_bolts/models/rl/common/distributions.py b/src/pl_bolts/models/rl/common/distributions.py
index c589c2db3a..495fbb0818 100644
--- a/src/pl_bolts/models/rl/common/distributions.py
+++ b/src/pl_bolts/models/rl/common/distributions.py
@@ -1,4 +1,5 @@
 """Distributions used in some continuous RL algorithms."""
+
 import torch
 
 from pl_bolts.utils.stability import under_review
diff --git a/src/pl_bolts/models/rl/common/gym_wrappers.py b/src/pl_bolts/models/rl/common/gym_wrappers.py
index 605b498a7a..573def6fa5 100644
--- a/src/pl_bolts/models/rl/common/gym_wrappers.py
+++ b/src/pl_bolts/models/rl/common/gym_wrappers.py
@@ -1,5 +1,6 @@
 """Set of wrapper functions for gym environments taken from
 https://github.com/Shmuma/ptan/blob/master/ptan/common/wrappers.py."""
+
 import collections
 
 import numpy as np
diff --git a/src/pl_bolts/models/rl/common/networks.py b/src/pl_bolts/models/rl/common/networks.py
index 63aad43a11..b920ae24ff 100644
--- a/src/pl_bolts/models/rl/common/networks.py
+++ b/src/pl_bolts/models/rl/common/networks.py
@@ -1,4 +1,5 @@
 """Series of networks used Based on implementations found here:"""
+
 import math
 from typing import Tuple
diff --git a/src/pl_bolts/models/rl/double_dqn_model.py b/src/pl_bolts/models/rl/double_dqn_model.py
index 2d76279c87..b2d36ca0c2 100644
--- a/src/pl_bolts/models/rl/double_dqn_model.py
+++ b/src/pl_bolts/models/rl/double_dqn_model.py
@@ -1,4 +1,5 @@
 """Double DQN."""
+
 import argparse
 from collections import OrderedDict
 from typing import Tuple
diff --git a/src/pl_bolts/models/rl/dqn_model.py b/src/pl_bolts/models/rl/dqn_model.py
index 567aa8d185..bfafce3997 100644
--- a/src/pl_bolts/models/rl/dqn_model.py
+++ b/src/pl_bolts/models/rl/dqn_model.py
@@ -1,4 +1,5 @@
 """Deep Q Network."""
+
 import argparse
 from collections import OrderedDict
 from typing import Dict, List, Optional, Tuple
diff --git a/src/pl_bolts/models/rl/dueling_dqn_model.py b/src/pl_bolts/models/rl/dueling_dqn_model.py
index 1e072d5ff8..d7e2b939e3 100644
--- a/src/pl_bolts/models/rl/dueling_dqn_model.py
+++ b/src/pl_bolts/models/rl/dueling_dqn_model.py
@@ -1,4 +1,5 @@
 """Dueling DQN."""
+
 import argparse
 
 from pytorch_lightning import Trainer
diff --git a/src/pl_bolts/models/rl/noisy_dqn_model.py b/src/pl_bolts/models/rl/noisy_dqn_model.py
index 76b4531c5b..bfb877cd8e 100644
--- a/src/pl_bolts/models/rl/noisy_dqn_model.py
+++ b/src/pl_bolts/models/rl/noisy_dqn_model.py
@@ -1,4 +1,5 @@
 """Noisy DQN."""
+
 import argparse
 from typing import Tuple
diff --git a/src/pl_bolts/models/rl/per_dqn_model.py b/src/pl_bolts/models/rl/per_dqn_model.py
index a864afb51b..1440587421 100644
--- a/src/pl_bolts/models/rl/per_dqn_model.py
+++ b/src/pl_bolts/models/rl/per_dqn_model.py
@@ -1,4 +1,5 @@
 """Prioritized Experience Replay DQN."""
+
 import argparse
 from collections import OrderedDict
 from typing import Tuple
diff --git a/src/pl_bolts/models/rl/ppo_model.py b/src/pl_bolts/models/rl/ppo_model.py
index 21bc0873c0..b861a66619 100644
--- a/src/pl_bolts/models/rl/ppo_model.py
+++ b/src/pl_bolts/models/rl/ppo_model.py
@@ -319,8 +319,7 @@ def configure_optimizers(self) -> List[Optimizer]:
         return optimizer_actor, optimizer_critic
 
     def optimizer_step(self, *args, **kwargs):
-        """Run ``num_optim_iters`` number of iterations of gradient descent on actor and critic for each data
-        sample."""
+        """Run ``num_optim_iters`` number of iterations of gradient descent on actor and critic for each data sample."""
         for _ in range(self.num_optim_iters):
             super().optimizer_step(*args, **kwargs)
diff --git a/src/pl_bolts/models/rl/sac_model.py b/src/pl_bolts/models/rl/sac_model.py
index 8c0bb2b712..56aba5f530 100644
--- a/src/pl_bolts/models/rl/sac_model.py
+++ b/src/pl_bolts/models/rl/sac_model.py
@@ -1,4 +1,5 @@
 """Soft Actor Critic."""
+
 import argparse
 from typing import Dict, List, Tuple
diff --git a/src/pl_bolts/models/self_supervised/__init__.py b/src/pl_bolts/models/self_supervised/__init__.py
index f501ee73a0..ede1cbf848 100644
--- a/src/pl_bolts/models/self_supervised/__init__.py
+++ b/src/pl_bolts/models/self_supervised/__init__.py
@@ -17,6 +17,7 @@
     classifications = classifier(representations)
 
 """
+
 from pl_bolts.models.self_supervised.amdim.amdim_module import AMDIM
 from pl_bolts.models.self_supervised.byol.byol_module import BYOL
 from pl_bolts.models.self_supervised.cpc.cpc_module import CPC_v2
diff --git a/src/pl_bolts/models/self_supervised/cpc/cpc_module.py b/src/pl_bolts/models/self_supervised/cpc/cpc_module.py
index 6f60b7a267..33317c9706 100644
--- a/src/pl_bolts/models/self_supervised/cpc/cpc_module.py
+++ b/src/pl_bolts/models/self_supervised/cpc/cpc_module.py
@@ -1,4 +1,5 @@
 """CPC V2."""
+
 import math
 from argparse import ArgumentParser
 from typing import Optional
diff --git a/src/pl_bolts/models/self_supervised/moco/moco_module.py b/src/pl_bolts/models/self_supervised/moco/moco_module.py
index 86d01147c3..19f6d99c63 100644
--- a/src/pl_bolts/models/self_supervised/moco/moco_module.py
+++ b/src/pl_bolts/models/self_supervised/moco/moco_module.py
@@ -8,6 +8,7 @@
 You may obtain a copy of the License from the LICENSE file present in this folder.
 """
+
 from copy import copy, deepcopy
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
@@ -228,6 +229,7 @@ def configure_optimizers(self) -> Tuple[List[optim.Optimizer], List[optim.lr_sch
         ``self.lr_scheduler_params``.
 
         If weight decay is specified, it will be applied only to convolutional layer weights.
+
         """
         if (
             ("weight_decay" in self.optimizer_params)
diff --git a/src/pl_bolts/models/self_supervised/moco/utils.py b/src/pl_bolts/models/self_supervised/moco/utils.py
index 116b52f979..030ec7079d 100644
--- a/src/pl_bolts/models/self_supervised/moco/utils.py
+++ b/src/pl_bolts/models/self_supervised/moco/utils.py
@@ -101,6 +101,7 @@ def concatenate_all(tensor: Tensor) -> Tensor:
     """Performs ``all_gather`` operation to concatenate the provided tensor from all devices.
 
     This function has no gradient.
+
     """
     gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
     torch.distributed.all_gather(gathered_tensor, tensor.contiguous())
diff --git a/src/pl_bolts/models/self_supervised/swav/swav_module.py b/src/pl_bolts/models/self_supervised/swav/swav_module.py
index e212358c40..fe5990676d 100644
--- a/src/pl_bolts/models/self_supervised/swav/swav_module.py
+++ b/src/pl_bolts/models/self_supervised/swav/swav_module.py
@@ -1,4 +1,5 @@
 """Adapted from official swav implementation: https://github.com/facebookresearch/swav."""
+
 import os
 from argparse import ArgumentParser
diff --git a/src/pl_bolts/models/self_supervised/swav/swav_resnet.py b/src/pl_bolts/models/self_supervised/swav/swav_resnet.py
index 65b09dbe9a..2c2ffe5d96 100644
--- a/src/pl_bolts/models/self_supervised/swav/swav_resnet.py
+++ b/src/pl_bolts/models/self_supervised/swav/swav_resnet.py
@@ -1,4 +1,5 @@
 """Adapted from: https://github.com/facebookresearch/swav/blob/master/src/resnet50.py."""
+
 import torch
 from torch import nn
diff --git a/src/pl_bolts/models/vision/pixel_cnn.py b/src/pl_bolts/models/vision/pixel_cnn.py
index 70be9dcd19..64edc93775 100644
--- a/src/pl_bolts/models/vision/pixel_cnn.py
+++ b/src/pl_bolts/models/vision/pixel_cnn.py
@@ -4,6 +4,7 @@
 : https: //arxiv.org/pdf/1905.09272.pdf (page 15
 Accessed: May 14, 2020.
 """
+
 from torch import nn
 from torch.nn import functional as F  # noqa: N812
diff --git a/src/pl_bolts/optimizers/lars.py b/src/pl_bolts/optimizers/lars.py
index 58ed202f24..7b9e41c45c 100644
--- a/src/pl_bolts/optimizers/lars.py
+++ b/src/pl_bolts/optimizers/lars.py
@@ -3,6 +3,7 @@
 - https://arxiv.org/pdf/1708.03888.pdf
 - https://github.com/pytorch/pytorch/blob/1.6/torch/optim/sgd.py
 """
+
 import torch
 from torch.optim.optimizer import Optimizer, required
diff --git a/tests/models/rl/unit/test_agents.py b/tests/models/rl/unit/test_agents.py
index e414f439f5..3a029d346b 100644
--- a/tests/models/rl/unit/test_agents.py
+++ b/tests/models/rl/unit/test_agents.py
@@ -1,4 +1,5 @@
 """Tests that the agent module works correctly."""
+
 from unittest import TestCase
 from unittest.mock import Mock