Skip to content

Commit

Permalink
Added support for EXIF orientation transform in read_image for JPEG
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Feb 23, 2024
1 parent b1123cf commit a40be0a
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 11 deletions.
21 changes: 20 additions & 1 deletion test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import torchvision.transforms.functional as F
from common_utils import assert_equal, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps
from torchvision.io.image import (
_read_png_16,
decode_image,
Expand Down Expand Up @@ -100,6 +100,25 @@ def test_decode_jpeg(img_path, pil_mode, mode):
assert abs_mean_diff < 2


@pytest.mark.parametrize("orientation", [1, 2, 3, 4, 5, 6, 7, 8, 0])
def test_decode_jpeg_with_exif_orientation(tmpdir, orientation):
fp = os.path.join(tmpdir, f"exif_oriented_{orientation}.jpg")
t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
im = F.to_pil_image(t)
exif = im.getexif()
exif[274] = orientation # set exif orientation
im.save(fp, "JPEG", exif=exif.tobytes())

data = read_file(fp)
output = decode_image(data, apply_exif_orientation=True)

pimg = Image.open(fp)
pimg = ImageOps.exif_transpose(pimg)

expected = F.pil_to_tensor(pimg)
torch.testing.assert_close(expected, output)


def test_decode_jpeg_errors():
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
Expand Down
10 changes: 8 additions & 2 deletions torchvision/csrc/io/image/cpu/decode_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
namespace vision {
namespace image {

torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
torch::Tensor decode_image(
const torch::Tensor& data,
ImageReadMode mode,
bool apply_exif_orientation) {
// Check that tensor is a CPU tensor
TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor");
// Check that the input tensor dtype is uint8
Expand All @@ -22,8 +25,11 @@ torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"

if (memcmp(jpeg_signature, datap, 3) == 0) {
return decode_jpeg(data, mode);
return decode_jpeg(data, mode, apply_exif_orientation);
} else if (memcmp(png_signature, datap, 4) == 0) {
TORCH_CHECK(
!apply_exif_orientation,
"Unsupported option apply_exif_orientation=true for PNG")
return decode_png(data, mode);
} else {
TORCH_CHECK(
Expand Down
3 changes: 2 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_image.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ namespace image {

C10_EXPORT torch::Tensor decode_image(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
bool apply_exif_orientation = false);

} // namespace image
} // namespace vision
64 changes: 62 additions & 2 deletions torchvision/csrc/io/image/cpu/decode_jpeg.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "decode_jpeg.h"
#include "common_jpeg.h"
#include "exif.h"

namespace vision {
namespace image {
Expand All @@ -12,6 +13,7 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
#else

using namespace detail;
using namespace exif_private;

namespace {

Expand Down Expand Up @@ -65,6 +67,8 @@ static void torch_jpeg_set_source_mgr(
src->len = len;
src->pub.bytes_in_buffer = len;
src->pub.next_input_byte = src->data;

jpeg_save_markers(cinfo, APP1, 0xffff);
}

inline unsigned char clamped_cmyk_rgb_convert(
Expand Down Expand Up @@ -121,7 +125,10 @@ void convert_line_cmyk_to_gray(

} // namespace

torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
torch::Tensor decode_jpeg(
const torch::Tensor& data,
ImageReadMode mode,
bool apply_exif_orientation) {
C10_LOG_API_USAGE_ONCE(
"torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg");
// Check that the input tensor dtype is uint8
Expand Down Expand Up @@ -191,6 +198,54 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
jpeg_calc_output_dimensions(&cinfo);
}

int exif_orientation = 0;
if (apply_exif_orientation) {
// Check for Exif marker APP1
jpeg_saved_marker_ptr exif_marker = 0;
jpeg_saved_marker_ptr cmarker = cinfo.marker_list;
while (cmarker && exif_marker == 0) {
if (cmarker->marker == APP1) {
exif_marker = cmarker;
}
cmarker = cmarker->next;
}

if (exif_marker) {
// Code below is inspired from OpenCV
// https://github.dev/opencv/opencv/blob/097891e311fae1d8354eb092a0fd0171e630d78c/modules/modules/imgcodecs/src/exif.cpp

// Bytes from Exif size field to the first TIFF header
constexpr size_t start_offset = 6;
if (exif_marker->data_length > start_offset) {
auto* exif_data_ptr = exif_marker->data + start_offset;
auto size = exif_marker->data_length - start_offset;
std::vector<unsigned char> exif_data_vec(
exif_data_ptr, exif_data_ptr + size);

auto endianness = get_endianness(exif_data_vec);

// Checking whether Tag Mark (0x002A) correspond to one contained in the
// Jpeg file
uint16_t tag_mark = get_uint16(exif_data_vec, endianness, 2);
if (tag_mark == REQ_EXIF_TAG_MARK) {
auto offset = get_uint32(exif_data_vec, endianness, 4);
size_t num_entry = get_uint16(exif_data_vec, endianness, offset);
offset += 2; // go to start of tag fields
constexpr size_t tiff_field_size = 12;
for (size_t entry = 0; entry < num_entry; entry++) {
// Here we just search for orientation tag and parse it
auto tag_num = get_uint16(exif_data_vec, endianness, offset);
if (tag_num == ORIENTATION_EXIF_TAG) {
exif_orientation =
get_uint16(exif_data_vec, endianness, offset + 8);
}
offset += tiff_field_size;
}
}
}
}
}

jpeg_start_decompress(&cinfo);

int height = cinfo.output_height;
Expand Down Expand Up @@ -227,7 +282,12 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {

jpeg_finish_decompress(&cinfo);
jpeg_destroy_decompress(&cinfo);
return tensor.permute({2, 0, 1});
auto output = tensor.permute({2, 0, 1});

if (apply_exif_orientation) {
return exif_orientation_transform(output, exif_orientation);
}
return output;
}
#endif // #if !JPEG_FOUND

Expand Down
3 changes: 2 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_jpeg.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ namespace image {

C10_EXPORT torch::Tensor decode_jpeg(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
bool apply_exif_orientation = false);

C10_EXPORT int64_t _jpeg_version();
C10_EXPORT bool _is_compiled_against_turbo();
Expand Down
99 changes: 99 additions & 0 deletions torchvision/csrc/io/image/cpu/exif.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#pragma once
#include <torch/types.h>

namespace vision {
namespace image {
namespace exif_private {

constexpr uint16_t APP1 = 0xe1;
constexpr uint16_t ENDIANNESS_INTEL = 0x49;
constexpr uint16_t ENDIANNESS_MOTO = 0x4d;
constexpr uint16_t REQ_EXIF_TAG_MARK = 0x2a;
constexpr uint16_t ORIENTATION_EXIF_TAG = 0x0112;
constexpr uint16_t INCORRECT_TAG = -1;

// Functions in this module are taken from OpenCV
// https://github.com/opencv/opencv/blob/097891e311fae1d8354eb092a0fd0171e630d78c/modules/modules/imgcodecs/src/exif.cpp
inline uint16_t get_endianness(const std::vector<unsigned char>& exif_data) {
if ((exif_data.size() < 1) ||
(exif_data.size() > 1 && exif_data[0] != exif_data[1])) {
return 0;
}
if (exif_data[0] == 'I') {
return ENDIANNESS_INTEL;
}
if (exif_data[0] == 'M') {
return ENDIANNESS_MOTO;
}
return 0;
}

inline uint16_t get_uint16(
const std::vector<unsigned char>& exif_data,
uint16_t endianness,
const size_t offset) {
if (offset + 1 >= exif_data.size()) {
return INCORRECT_TAG;
}

if (endianness == ENDIANNESS_INTEL) {
return exif_data[offset] + (exif_data[offset + 1] << 8);
}
return (exif_data[offset] << 8) + exif_data[offset + 1];
}

inline uint32_t get_uint32(
const std::vector<unsigned char>& exif_data,
uint16_t endianness,
const size_t offset) {
if (offset + 3 >= exif_data.size()) {
return INCORRECT_TAG;
}

if (endianness == ENDIANNESS_INTEL) {
return exif_data[offset] + (exif_data[offset + 1] << 8) +
(exif_data[offset + 2] << 16) + (exif_data[offset + 3] << 24);
}
return (exif_data[offset] << 24) + (exif_data[offset + 1] << 16) +
(exif_data[offset + 2] << 8) + exif_data[offset + 3];
}

constexpr uint16_t IMAGE_ORIENTATION_TL = 1; // normal orientation
constexpr uint16_t IMAGE_ORIENTATION_TR = 2; // needs horizontal flip
constexpr uint16_t IMAGE_ORIENTATION_BR = 3; // needs 180 rotation
constexpr uint16_t IMAGE_ORIENTATION_BL = 4; // needs vertical flip
constexpr uint16_t IMAGE_ORIENTATION_LT =
5; // mirrored horizontal & rotate 270 CW
constexpr uint16_t IMAGE_ORIENTATION_RT = 6; // rotate 90 CW
constexpr uint16_t IMAGE_ORIENTATION_RB =
7; // mirrored horizontal & rotate 90 CW
constexpr uint16_t IMAGE_ORIENTATION_LB = 8; // needs 270 CW rotation

inline torch::Tensor exif_orientation_transform(
const torch::Tensor& image,
int orientation) {
if (orientation == IMAGE_ORIENTATION_TL) {
return image;
} else if (orientation == IMAGE_ORIENTATION_TR) {
return image.flip(-1);
} else if (orientation == IMAGE_ORIENTATION_BR) {
// needs 180 rotation equivalent to
// flip both horizontally and vertically
return image.flip({-2, -1});
} else if (orientation == IMAGE_ORIENTATION_BL) {
return image.flip(-2);
} else if (orientation == IMAGE_ORIENTATION_LT) {
return image.transpose(-1, -2);
} else if (orientation == IMAGE_ORIENTATION_RT) {
return image.transpose(-1, -2).flip(-1);
} else if (orientation == IMAGE_ORIENTATION_RB) {
return image.transpose(-1, -2).flip({-2, -1});
} else if (orientation == IMAGE_ORIENTATION_LB) {
return image.transpose(-1, -2).flip(-2);
}
return image;
}

} // namespace exif_private
} // namespace image
} // namespace vision
16 changes: 12 additions & 4 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,9 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
write_file(filename, output)


def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
def decode_image(
input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False
) -> torch.Tensor:
"""
Detects whether an image is a JPEG or PNG and performs the appropriate
operation to decode the image into a 3 dimensional RGB or grayscale Tensor.
Expand All @@ -227,17 +229,21 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
Default: ``ImageReadMode.UNCHANGED``.
See ``ImageReadMode`` class for more information on various
available modes.
apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
Default: False. Only implemented for JPEG format
Returns:
output (Tensor[image_channels, image_height, image_width])
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(decode_image)
output = torch.ops.image.decode_image(input, mode.value)
output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
return output


def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
def read_image(
path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False
) -> torch.Tensor:
"""
Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
Expand All @@ -249,14 +255,16 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
Default: ``ImageReadMode.UNCHANGED``.
See ``ImageReadMode`` class for more information on various
available modes.
apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
Default: False. Only implemented for JPEG format
Returns:
output (Tensor[image_channels, image_height, image_width])
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(read_image)
data = read_file(path)
return decode_image(data, mode)
return decode_image(data, mode, apply_exif_orientation=apply_exif_orientation)


def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
Expand Down

0 comments on commit a40be0a

Please sign in to comment.