Commit
Merge branch 'add_gpu_encode' of https://github.com/deekay42/vision into add_gpu_encode
deekay42 committed Jun 11, 2024
2 parents 0a88d27 + 136f790 commit df60183
Showing 18 changed files with 182 additions and 60 deletions.
File renamed without changes.
13 changes: 10 additions & 3 deletions setup.py
@@ -5,6 +5,7 @@
import shutil
import subprocess
import sys
import warnings

import torch
from pkg_resources import DistributionNotFound, get_distribution, parse_version
@@ -138,7 +139,6 @@ def get_extensions():
+ glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
+ glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
)
source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))

print("Compiling extensions with following flags:")
force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
@@ -204,8 +204,15 @@ def get_extensions():
define_macros += [("WITH_HIP", None)]
nvcc_flags = []
extra_compile_args["nvcc"] = nvcc_flags
elif torch.backends.mps.is_available() or force_mps:
sources += source_mps

# FIXME: MPS build breaks custom ops registration, so it was disabled.
# See https://github.com/pytorch/vision/issues/8456.
# TODO: Fix MPS build, remove warning below, and put back commented-out elif block.
if force_mps:
warnings.warn("MPS build is temporarily disabled!!!!")
# elif torch.backends.mps.is_available() or force_mps:
# source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))
# sources += source_mps

if sys.platform == "win32":
define_macros += [("torchvision_EXPORTS", None)]
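For context, a minimal sketch of the build gating this hunk leaves in place (assuming the surrounding `get_extensions()` locals `sources` and `extensions_dir`): an explicit `FORCE_MPS=1` now only emits a warning, and the `.mm` kernels never enter the build.

```python
# Sketch, not the verbatim setup.py: MPS sources stay out of the build
# until https://github.com/pytorch/vision/issues/8456 is resolved.
import os
import warnings

def collect_mps_sources(sources, extensions_dir):
    force_mps = os.getenv("FORCE_MPS", "0") == "1"
    if force_mps:
        warnings.warn("MPS build is temporarily disabled!!!!")
    # Previously: sources += glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))
    return sources
```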
3 changes: 2 additions & 1 deletion test/conftest.py
@@ -49,7 +49,8 @@ def pytest_collection_modifyitems(items):
# There are special cases though, see below
item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG))

if needs_mps and not torch.backends.mps.is_available():
# TODO: uncomment when MPS works again - see FIXME in setup.py
if needs_mps: # and not torch.backends.mps.is_available():
item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG))

if IN_FBCODE:
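A hedged sketch of the collection-hook pattern used above (the marker and message names are assumed from the rest of conftest.py): every `needs_mps` test is skipped unconditionally until the MPS build returns.

```python
# Sketch of the pytest_collection_modifyitems pattern; MPS_NOT_AVAILABLE_MSG
# is assumed to be defined alongside the other skip messages in conftest.py.
import pytest

MPS_NOT_AVAILABLE_MSG = "MPS device not available"

def pytest_collection_modifyitems(items):
    for item in items:
        if item.get_closest_marker("needs_mps") is not None:
            # TODO: restore `and not torch.backends.mps.is_available()` once MPS works again
            item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG))
```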
5 changes: 4 additions & 1 deletion test/test_image.py
@@ -11,7 +11,7 @@
import requests
import torch
import torchvision.transforms.functional as F
from common_utils import assert_equal, cpu_and_cuda, needs_cuda
from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
from torchvision.io.image import (
_read_png_16,
@@ -764,6 +764,9 @@ def test_decode_gif(tmpdir, name, scripted):

path = tmpdir / f"{name}.gif"
if name == "earth":
if IN_OSS_CI:
# TODO: Fix this... one day.
pytest.skip("Skipping 'earth' test as it's flaky on OSS CI")
url = "https://upload.wikimedia.org/wikipedia/commons/2/2c/Rotating_earth_%28large%29.gif"
else:
url = f"https://sourceforge.net/p/giflib/code/ci/master/tree/pic/{name}.gif?format=raw"
63 changes: 54 additions & 9 deletions test/test_transforms_v2.py
@@ -99,7 +99,7 @@ def _script(obj):
return torch.jit.script(obj)
except Exception as error:
name = getattr(obj, "__name__", obj.__class__.__name__)
raise AssertionError(f"Trying to `torch.jit.script` '{name}' raised the error above.") from error
raise AssertionError(f"Trying to `torch.jit.script` `{name}` raised the error above.") from error


def _check_kernel_scripted_vs_eager(kernel, input, *args, rtol, atol, **kwargs):
@@ -553,10 +553,12 @@ def affine_bounding_boxes(bounding_boxes):

class TestResize:
INPUT_SIZE = (17, 11)
OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)]
OUTPUT_SIZES = [17, [17], (17,), None, [12, 13], (12, 13)]

def _make_max_size_kwarg(self, *, use_max_size, size):
if use_max_size:
if size is None:
max_size = min(list(self.INPUT_SIZE))
elif use_max_size:
if not (isinstance(size, int) or len(size) == 1):
# This would result in a `ValueError`
return None
@@ -568,10 +570,13 @@ def _make_max_size_kwarg(self, *, use_max_size, size):
return dict(max_size=max_size)

def _compute_output_size(self, *, input_size, size, max_size):
if not (isinstance(size, int) or len(size) == 1):
if size is None:
size = max_size

elif not (isinstance(size, int) or len(size) == 1):
return tuple(size)

if not isinstance(size, int):
elif not isinstance(size, int):
size = size[0]

old_height, old_width = input_size
@@ -658,10 +663,13 @@ def test_kernel_video(self):
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
)
def test_functional(self, size, make_input):
max_size_kwarg = self._make_max_size_kwarg(use_max_size=size is None, size=size)

check_functional(
F.resize,
make_input(self.INPUT_SIZE),
size=size,
**max_size_kwarg,
antialias=True,
check_scripted_smoke=not isinstance(size, int),
)
@@ -695,11 +703,13 @@ def test_functional_signature(self, kernel, input_type):
],
)
def test_transform(self, size, device, make_input):
max_size_kwarg = self._make_max_size_kwarg(use_max_size=size is None, size=size)

check_transform(
transforms.Resize(size=size, antialias=True),
transforms.Resize(size=size, **max_size_kwarg, antialias=True),
make_input(self.INPUT_SIZE, device=device),
# atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
check_v1_compatibility=dict(rtol=0, atol=1),
check_v1_compatibility=dict(rtol=0, atol=1) if size is not None else False,
)

def _check_output_size(self, input, output, *, size, max_size):
@@ -801,7 +811,11 @@ def test_functional_pil_antialias_warning(self):
],
)
def test_max_size_error(self, size, make_input):
if isinstance(size, int) or len(size) == 1:
if size is None:
# value can be anything other than an integer
max_size = None
match = "max_size must be an integer when size is None"
elif isinstance(size, int) or len(size) == 1:
max_size = (size if isinstance(size, int) else size[0]) - 1
match = "must be strictly greater than the requested size"
else:
@@ -812,6 +826,37 @@ def test_max_size_error(self, size, make_input):
with pytest.raises(ValueError, match=match):
F.resize(make_input(self.INPUT_SIZE), size=size, max_size=max_size, antialias=True)

if isinstance(size, list) and len(size) != 1:
with pytest.raises(ValueError, match="max_size should only be passed if size is None or specifies"):
F.resize(make_input(self.INPUT_SIZE), size=size, max_size=500)

@pytest.mark.parametrize(
"input_size, max_size, expected_size",
[
((10, 10), 10, (10, 10)),
((10, 20), 40, (20, 40)),
((20, 10), 40, (40, 20)),
((10, 20), 10, (5, 10)),
((20, 10), 10, (10, 5)),
],
)
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image_pil,
make_image,
make_bounding_boxes,
make_segmentation_mask,
make_detection_masks,
make_video,
],
)
def test_resize_size_none(self, input_size, max_size, expected_size, make_input):
img = make_input(input_size)
out = F.resize(img, size=None, max_size=max_size)
assert F.get_size(out)[-2:] == list(expected_size)

@pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
@pytest.mark.parametrize(
"make_input",
@@ -834,7 +879,7 @@ def test_interpolation_int(self, interpolation, make_input):
assert_equal(actual, expected)

def test_transform_unknown_size_error(self):
with pytest.raises(ValueError, match="size can either be an integer or a sequence of one or two integers"):
with pytest.raises(ValueError, match="size can be an integer, a sequence of one or two integers, or None"):
transforms.Resize(size=object())

@pytest.mark.parametrize(
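The `size=None` cases above pin down the new rule: the longer edge is scaled to `max_size` and the shorter edge follows proportionally. A small self-contained sketch of that arithmetic (not the library implementation) that reproduces the `expected_size` column of the parametrization in `test_resize_size_none`:

```python
# Mirrors the size=None output-size rule exercised by test_resize_size_none.
def expected_output_size(input_size, max_size):
    h, w = input_size
    short, long = (w, h) if w <= h else (h, w)
    new_short, new_long = int(max_size * short / long), max_size
    new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
    return (new_h, new_w)

assert expected_output_size((10, 20), 40) == (20, 40)
assert expected_output_size((20, 10), 10) == (10, 5)
```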
6 changes: 4 additions & 2 deletions torchvision/__init__.py
@@ -3,9 +3,11 @@
from modulefinder import Module

import torch
from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils

from .extension import _HAS_OPS
# Don't re-order these, we need to load the _C extension (done when importing
# .extensions) before entering _meta_registrations.
from .extension import _HAS_OPS # usort:skip
from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils # usort:skip

try:
from .version import __version__ # noqa: F401
2 changes: 1 addition & 1 deletion torchvision/_meta_registrations.py
@@ -160,7 +160,7 @@ def meta_ps_roi_pool_backward(
return grad.new_empty((batch_size, channels, height, width))


@torch._custom_ops.impl_abstract("torchvision::nms")
@torch.library.register_fake("torchvision::nms")
def meta_nms(dets, scores, iou_threshold):
torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
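`torch.library.register_fake` is the current spelling of what `torch._custom_ops.impl_abstract` used to do: it registers a shape-only ("fake") kernel so meta tensors and `torch.compile` can trace `torchvision::nms` without running it. A hedged sketch of the pattern; the dynamic-size part is assumed from the surrounding file and is not shown in this hunk:

```python
import torch

@torch.library.register_fake("torchvision::nms")
def meta_nms(dets, scores, iou_threshold):
    torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
    # NMS keeps a data-dependent number of boxes, so the fake kernel
    # allocates a symbolic output length instead of a concrete one.
    ctx = torch.library.get_ctx()
    num_to_keep = ctx.new_dynamic_size()
    return dets.new_empty(num_to_keep, dtype=torch.long)
```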
18 changes: 18 additions & 0 deletions torchvision/csrc/io/decoder/audio_sampler.cpp
@@ -48,6 +48,23 @@ bool AudioSampler::init(const SamplerParameters& params) {
return false;
}

#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(57, 28, 100)
SwrContext* swrContext_ = NULL;
AVChannelLayout channel_out;
AVChannelLayout channel_in;
av_channel_layout_default(&channel_out, params.out.audio.channels);
av_channel_layout_default(&channel_in, params.in.audio.channels);
int ret = swr_alloc_set_opts2(
&swrContext_,
&channel_out,
(AVSampleFormat)params.out.audio.format,
params.out.audio.samples,
&channel_in,
(AVSampleFormat)params.in.audio.format,
params.in.audio.samples,
0,
logCtx_);
#else
swrContext_ = swr_alloc_set_opts(
nullptr,
av_get_default_channel_layout(params.out.audio.channels),
@@ -58,6 +75,7 @@ bool AudioSampler::init(const SamplerParameters& params) {
params.in.audio.samples,
0,
logCtx_);
#endif
if (swrContext_ == nullptr) {
LOG(ERROR) << "Cannot allocate SwrContext";
return false;
24 changes: 20 additions & 4 deletions torchvision/csrc/io/decoder/audio_stream.cpp
@@ -6,26 +6,36 @@
namespace ffmpeg {

namespace {
static int get_nb_channels(const AVFrame* frame, const AVCodecContext* codec) {
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(57, 28, 100)
return frame ? frame->ch_layout.nb_channels : codec->ch_layout.nb_channels;
#else
return frame ? frame->channels : codec->channels;
#endif
}

bool operator==(const AudioFormat& x, const AVFrame& y) {
return x.samples == static_cast<size_t>(y.sample_rate) &&
x.channels == static_cast<size_t>(y.channels) && x.format == y.format;
x.channels == static_cast<size_t>(get_nb_channels(&y, nullptr)) &&
x.format == y.format;
}

bool operator==(const AudioFormat& x, const AVCodecContext& y) {
return x.samples == static_cast<size_t>(y.sample_rate) &&
x.channels == static_cast<size_t>(y.channels) && x.format == y.sample_fmt;
x.channels == static_cast<size_t>(get_nb_channels(nullptr, &y)) &&
x.format == y.sample_fmt;
}

AudioFormat& toAudioFormat(AudioFormat& x, const AVFrame& y) {
x.samples = y.sample_rate;
x.channels = y.channels;
x.channels = get_nb_channels(&y, nullptr);
x.format = y.format;
return x;
}

AudioFormat& toAudioFormat(AudioFormat& x, const AVCodecContext& y) {
x.samples = y.sample_rate;
x.channels = y.channels;
x.channels = get_nb_channels(nullptr, &y);
x.format = y.sample_fmt;
return x;
}
@@ -54,9 +64,15 @@ int AudioStream::initFormat() {
if (format_.format.audio.samples == 0) {
format_.format.audio.samples = codecCtx_->sample_rate;
}
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(57, 28, 100)
if (format_.format.audio.channels == 0) {
format_.format.audio.channels = codecCtx_->ch_layout.nb_channels;
}
#else
if (format_.format.audio.channels == 0) {
format_.format.audio.channels = codecCtx_->channels;
}
#endif
if (format_.format.audio.format == AV_SAMPLE_FMT_NONE) {
format_.format.audio.format = codecCtx_->sample_fmt;
}
2 changes: 0 additions & 2 deletions torchvision/csrc/io/video/video.h
@@ -6,8 +6,6 @@
#include "../decoder/memory_buffer.h"
#include "../decoder/sync_decoder.h"

using namespace ffmpeg;

namespace vision {
namespace video {

1 change: 1 addition & 0 deletions torchvision/csrc/ops/nms.cpp
@@ -19,6 +19,7 @@ at::Tensor nms(
}

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.set_python_module("torchvision._meta_registrations");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"));
}
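`set_python_module` ties the C++ op to the Python module that holds its fake implementation, so importing `torchvision` is enough for tracing to find both. A hedged usage sketch of calling the registered op directly:

```python
import torch
import torchvision  # loads the _C extension and the meta registrations

dets = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]])
scores = torch.tensor([0.9, 0.8])
keep = torch.ops.torchvision.nms(dets, scores, 0.5)  # indices of kept boxes
```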
4 changes: 3 additions & 1 deletion torchvision/datasets/sbd.py
@@ -81,7 +81,9 @@ def __init__(
for f in ["cls", "img", "inst", "train.txt", "val.txt"]:
old_path = os.path.join(extracted_ds_root, f)
shutil.move(old_path, sbd_root)
download_url(self.voc_train_url, sbd_root, self.voc_split_filename, self.voc_split_md5)
if self.image_set == "train_noval":
# Note: this is failing as of June 2024 https://github.com/pytorch/vision/issues/8471
download_url(self.voc_train_url, sbd_root, self.voc_split_filename, self.voc_split_md5)

if not os.path.isdir(sbd_root):
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
2 changes: 1 addition & 1 deletion torchvision/io/image.py
@@ -281,7 +281,7 @@ def read_image(
The values of the output tensor are uint8 in [0, 255].
Args:
path (str or ``pathlib.Path``): path of the JPEG or PNG image.
path (str or ``pathlib.Path``): path of the JPEG, PNG or GIF image.
mode (ImageReadMode): the read mode used for optionally converting the image.
Default: ``ImageReadMode.UNCHANGED``.
See ``ImageReadMode`` class for more information on various
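The docstring fix reflects that `read_image` also decodes GIFs. A hedged usage sketch (the file path is hypothetical):

```python
from torchvision.io import read_image

img = read_image("frame.gif")  # hypothetical path; JPEG, PNG, and GIF all work
print(img.dtype)  # torch.uint8, values in [0, 255]
```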
1 change: 0 additions & 1 deletion torchvision/ops/roi_align.py
@@ -2,7 +2,6 @@
from typing import List, Union

import torch
import torch._dynamo
import torch.fx
from torch import nn, Tensor
from torch._dynamo.utils import is_compile_supported
19 changes: 14 additions & 5 deletions torchvision/transforms/functional.py
@@ -351,13 +351,22 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool


def _compute_resized_output_size(
image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
image_size: Tuple[int, int],
size: Optional[List[int]],
max_size: Optional[int] = None,
allow_size_none: bool = False, # only True in v2
) -> List[int]:
if len(size) == 1: # specified size only for the smallest edge
h, w = image_size
short, long = (w, h) if w <= h else (h, w)
h, w = image_size
short, long = (w, h) if w <= h else (h, w)
if size is None:
if not allow_size_none:
raise ValueError("This should never happen!!")
if not isinstance(max_size, int):
raise ValueError(f"max_size must be an integer when size is None, but got {max_size} instead.")
new_short, new_long = int(max_size * short / long), max_size
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
elif len(size) == 1: # specified size only for the smallest edge
requested_new_short = size if isinstance(size, int) else size[0]

new_short, new_long = requested_new_short, int(requested_new_short * long / short)

if max_size is not None:
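In v2, where `allow_size_none=True`, this new branch lets `max_size` alone drive the output size: the longer edge is scaled to `max_size` and the aspect ratio is preserved. A hedged usage sketch against the v2 functional API (assumes a build that contains this change):

```python
import torch
from torchvision.transforms.v2 import functional as F

img = torch.randint(0, 256, (3, 100, 200), dtype=torch.uint8)
out = F.resize(img, size=None, max_size=50)
print(out.shape)  # torch.Size([3, 25, 50]): long edge capped at max_size
```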
