diff --git a/test/test_image.py b/test/test_image.py index c5784ab8f76..afbf01a6fff 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -9,7 +9,7 @@ import torch import torchvision.transforms.functional as F from common_utils import assert_equal, needs_cuda -from PIL import __version__ as PILLOW_VERSION, Image +from PIL import __version__ as PILLOW_VERSION, Image, ImageOps from torchvision.io.image import ( _read_png_16, decode_image, @@ -100,6 +100,25 @@ def test_decode_jpeg(img_path, pil_mode, mode): assert abs_mean_diff < 2 +@pytest.mark.parametrize("orientation", [1, 2, 3, 4, 5, 6, 7, 8, 0]) +def test_decode_jpeg_with_exif_orientation(tmpdir, orientation): + fp = os.path.join(tmpdir, f"exif_oriented_{orientation}.jpg") + t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8) + im = F.to_pil_image(t) + exif = im.getexif() + exif[274] = orientation # set exif orientation + im.save(fp, "JPEG", exif=exif.tobytes()) + + data = read_file(fp) + output = decode_image(data, apply_exif_orientation=True) + + pimg = Image.open(fp) + pimg = ImageOps.exif_transpose(pimg) + + expected = F.pil_to_tensor(pimg) + torch.testing.assert_close(expected, output) + + def test_decode_jpeg_errors(): with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) diff --git a/torchvision/csrc/io/image/cpu/decode_image.cpp b/torchvision/csrc/io/image/cpu/decode_image.cpp index da4dc5833de..cf8867611de 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.cpp +++ b/torchvision/csrc/io/image/cpu/decode_image.cpp @@ -6,7 +6,10 @@ namespace vision { namespace image { -torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) { +torch::Tensor decode_image( + const torch::Tensor& data, + ImageReadMode mode, + bool apply_exif_orientation) { // Check that tensor is a CPU tensor TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor"); // Check that the input tensor dtype is uint8 @@ -22,8 +25,11 @@ torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) { const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG" if (memcmp(jpeg_signature, datap, 3) == 0) { - return decode_jpeg(data, mode); + return decode_jpeg(data, mode, apply_exif_orientation); } else if (memcmp(png_signature, datap, 4) == 0) { + TORCH_CHECK( + !apply_exif_orientation, + "Unsupported option apply_exif_orientation=true for PNG") return decode_png(data, mode); } else { TORCH_CHECK( diff --git a/torchvision/csrc/io/image/cpu/decode_image.h b/torchvision/csrc/io/image/cpu/decode_image.h index 853d6d91afa..f0e66d397ac 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.h +++ b/torchvision/csrc/io/image/cpu/decode_image.h @@ -8,7 +8,8 @@ namespace image { C10_EXPORT torch::Tensor decode_image( const torch::Tensor& data, - ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, + bool apply_exif_orientation = false); } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp index 63a4e5b42ec..e8a52c69593 100644 --- a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp +++ b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp @@ -1,5 +1,6 @@ #include "decode_jpeg.h" #include "common_jpeg.h" +#include "exif.h" namespace vision { namespace image { @@ -12,6 +13,7 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) { #else using namespace detail; +using namespace exif_private; namespace { @@ -65,6 +67,8 @@ static void torch_jpeg_set_source_mgr( src->len = len; src->pub.bytes_in_buffer = len; src->pub.next_input_byte = src->data; + + jpeg_save_markers(cinfo, APP1, 0xffff); } inline unsigned char clamped_cmyk_rgb_convert( @@ -121,7 +125,10 @@ void convert_line_cmyk_to_gray( } // namespace -torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) { +torch::Tensor decode_jpeg( + const torch::Tensor& data, + ImageReadMode mode, + bool apply_exif_orientation) { C10_LOG_API_USAGE_ONCE( "torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg"); // Check that the input tensor dtype is uint8 @@ -191,6 +198,54 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) { jpeg_calc_output_dimensions(&cinfo); } + int exif_orientation = 0; + if (apply_exif_orientation) { + // Check for Exif marker APP1 + jpeg_saved_marker_ptr exif_marker = 0; + jpeg_saved_marker_ptr cmarker = cinfo.marker_list; + while (cmarker && exif_marker == 0) { + if (cmarker->marker == APP1) { + exif_marker = cmarker; + } + cmarker = cmarker->next; + } + + if (exif_marker) { + // Code below is inspired from OpenCV + // https://github.dev/opencv/opencv/blob/097891e311fae1d8354eb092a0fd0171e630d78c/modules/modules/imgcodecs/src/exif.cpp + + // Bytes from Exif size field to the first TIFF header + constexpr size_t start_offset = 6; + if (exif_marker->data_length > start_offset) { + auto* exif_data_ptr = exif_marker->data + start_offset; + auto size = exif_marker->data_length - start_offset; + std::vector exif_data_vec( + exif_data_ptr, exif_data_ptr + size); + + auto endianness = get_endianness(exif_data_vec); + + // Checking whether Tag Mark (0x002A) correspond to one contained in the + // Jpeg file + uint16_t tag_mark = get_uint16(exif_data_vec, endianness, 2); + if (tag_mark == REQ_EXIF_TAG_MARK) { + auto offset = get_uint32(exif_data_vec, endianness, 4); + size_t num_entry = get_uint16(exif_data_vec, endianness, offset); + offset += 2; // go to start of tag fields + constexpr size_t tiff_field_size = 12; + for (size_t entry = 0; entry < num_entry; entry++) { + // Here we just search for orientation tag and parse it + auto tag_num = get_uint16(exif_data_vec, endianness, offset); + if (tag_num == ORIENTATION_EXIF_TAG) { + exif_orientation = + get_uint16(exif_data_vec, endianness, offset + 8); + } + offset += tiff_field_size; + } + } + } + } + } + jpeg_start_decompress(&cinfo); int height = cinfo.output_height; @@ -227,7 +282,12 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) { jpeg_finish_decompress(&cinfo); jpeg_destroy_decompress(&cinfo); - return tensor.permute({2, 0, 1}); + auto output = tensor.permute({2, 0, 1}); + + if (apply_exif_orientation) { + return exif_orientation_transform(output, exif_orientation); + } + return output; } #endif // #if !JPEG_FOUND diff --git a/torchvision/csrc/io/image/cpu/decode_jpeg.h b/torchvision/csrc/io/image/cpu/decode_jpeg.h index 254e94680b6..e0c9a24c846 100644 --- a/torchvision/csrc/io/image/cpu/decode_jpeg.h +++ b/torchvision/csrc/io/image/cpu/decode_jpeg.h @@ -8,7 +8,8 @@ namespace image { C10_EXPORT torch::Tensor decode_jpeg( const torch::Tensor& data, - ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, + bool apply_exif_orientation = false); C10_EXPORT int64_t _jpeg_version(); C10_EXPORT bool _is_compiled_against_turbo(); diff --git a/torchvision/csrc/io/image/cpu/exif.h b/torchvision/csrc/io/image/cpu/exif.h new file mode 100644 index 00000000000..4a0b9b843f7 --- /dev/null +++ b/torchvision/csrc/io/image/cpu/exif.h @@ -0,0 +1,99 @@ +#pragma once +#include + +namespace vision { +namespace image { +namespace exif_private { + +constexpr uint16_t APP1 = 0xe1; +constexpr uint16_t ENDIANNESS_INTEL = 0x49; +constexpr uint16_t ENDIANNESS_MOTO = 0x4d; +constexpr uint16_t REQ_EXIF_TAG_MARK = 0x2a; +constexpr uint16_t ORIENTATION_EXIF_TAG = 0x0112; +constexpr uint16_t INCORRECT_TAG = -1; + +// Functions in this module are taken from OpenCV +// https://github.com/opencv/opencv/blob/097891e311fae1d8354eb092a0fd0171e630d78c/modules/modules/imgcodecs/src/exif.cpp +inline uint16_t get_endianness(const std::vector& exif_data) { + if ((exif_data.size() < 1) || + (exif_data.size() > 1 && exif_data[0] != exif_data[1])) { + return 0; + } + if (exif_data[0] == 'I') { + return ENDIANNESS_INTEL; + } + if (exif_data[0] == 'M') { + return ENDIANNESS_MOTO; + } + return 0; +} + +inline uint16_t get_uint16( + const std::vector& exif_data, + uint16_t endianness, + const size_t offset) { + if (offset + 1 >= exif_data.size()) { + return INCORRECT_TAG; + } + + if (endianness == ENDIANNESS_INTEL) { + return exif_data[offset] + (exif_data[offset + 1] << 8); + } + return (exif_data[offset] << 8) + exif_data[offset + 1]; +} + +inline uint32_t get_uint32( + const std::vector& exif_data, + uint16_t endianness, + const size_t offset) { + if (offset + 3 >= exif_data.size()) { + return INCORRECT_TAG; + } + + if (endianness == ENDIANNESS_INTEL) { + return exif_data[offset] + (exif_data[offset + 1] << 8) + + (exif_data[offset + 2] << 16) + (exif_data[offset + 3] << 24); + } + return (exif_data[offset] << 24) + (exif_data[offset + 1] << 16) + + (exif_data[offset + 2] << 8) + exif_data[offset + 3]; +} + +constexpr uint16_t IMAGE_ORIENTATION_TL = 1; // normal orientation +constexpr uint16_t IMAGE_ORIENTATION_TR = 2; // needs horizontal flip +constexpr uint16_t IMAGE_ORIENTATION_BR = 3; // needs 180 rotation +constexpr uint16_t IMAGE_ORIENTATION_BL = 4; // needs vertical flip +constexpr uint16_t IMAGE_ORIENTATION_LT = + 5; // mirrored horizontal & rotate 270 CW +constexpr uint16_t IMAGE_ORIENTATION_RT = 6; // rotate 90 CW +constexpr uint16_t IMAGE_ORIENTATION_RB = + 7; // mirrored horizontal & rotate 90 CW +constexpr uint16_t IMAGE_ORIENTATION_LB = 8; // needs 270 CW rotation + +inline torch::Tensor exif_orientation_transform( + const torch::Tensor& image, + int orientation) { + if (orientation == IMAGE_ORIENTATION_TL) { + return image; + } else if (orientation == IMAGE_ORIENTATION_TR) { + return image.flip(-1); + } else if (orientation == IMAGE_ORIENTATION_BR) { + // needs 180 rotation equivalent to + // flip both horizontally and vertically + return image.flip({-2, -1}); + } else if (orientation == IMAGE_ORIENTATION_BL) { + return image.flip(-2); + } else if (orientation == IMAGE_ORIENTATION_LT) { + return image.transpose(-1, -2); + } else if (orientation == IMAGE_ORIENTATION_RT) { + return image.transpose(-1, -2).flip(-1); + } else if (orientation == IMAGE_ORIENTATION_RB) { + return image.transpose(-1, -2).flip({-2, -1}); + } else if (orientation == IMAGE_ORIENTATION_LB) { + return image.transpose(-1, -2).flip(-2); + } + return image; +} + +} // namespace exif_private +} // namespace image +} // namespace vision diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 900ac4e36d2..8dd61d0d469 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -212,7 +212,9 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): write_file(filename, output) -def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: +def decode_image( + input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False +) -> torch.Tensor: """ Detects whether an image is a JPEG or PNG and performs the appropriate operation to decode the image into a 3 dimensional RGB or grayscale Tensor. @@ -227,17 +229,21 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN Default: ``ImageReadMode.UNCHANGED``. See ``ImageReadMode`` class for more information on various available modes. + apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor. + Default: False. Only implemented for JPEG format Returns: output (Tensor[image_channels, image_height, image_width]) """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(decode_image) - output = torch.ops.image.decode_image(input, mode.value) + output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation) return output -def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: +def read_image( + path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False +) -> torch.Tensor: """ Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor. Optionally converts the image to the desired format. @@ -249,6 +255,8 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc Default: ``ImageReadMode.UNCHANGED``. See ``ImageReadMode`` class for more information on various available modes. + apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor. + Default: False. Only implemented for JPEG format Returns: output (Tensor[image_channels, image_height, image_width]) @@ -256,7 +264,7 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(read_image) data = read_file(path) - return decode_image(data, mode) + return decode_image(data, mode, apply_exif_orientation=apply_exif_orientation) def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: