forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTensorTransformations.cpp
95 lines (77 loc) · 2.98 KB
/
TensorTransformations.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include "ATen/native/TensorTransformations.h"
#include <ATen/NativeFunctions.h>
#include <c10/util/Exception.h>
#include <algorithm>
#include <vector>
namespace at {
namespace native {
// Reverse the order of elements of `self` along each dimension in `dims` (CPU path).
// Strategy: build a meshgrid of reversed index tensors for the flip dims and use
// advanced indexing to gather, then undo the dimension reordering that advanced
// indexing introduces when the indexed dims are non-consecutive.
Tensor flip_cpu(const Tensor& self, IntList dims) {
  const int64_t total_dims = self.dim(), flip_dims_size = dims.size();
  flip_check_errors(total_dims, flip_dims_size, dims);

  auto flip_dims_v = dims.vec();
  wrap_all_dims(flip_dims_v, total_dims);  // normalize negative dims to [0, total_dims)
  std::sort(flip_dims_v.begin(), flip_dims_v.end());

  // final_indices[d] is empty (undefined Tensor) for non-flipped dims d, which
  // advanced indexing treats as "take the whole dimension".
  auto final_indices = std::vector<at::Tensor>(total_dims);
  auto indices = std::vector<at::Tensor>(flip_dims_size);
  for (int64_t i = 0; i < flip_dims_size; i++) {
    // Reversed index sequence [size-1, size-2, ..., 0] for this flip dim.
    indices[i] = at::arange(self.size(flip_dims_v[i]) - 1, -1, -1, self.type().toScalarType(at::kLong));
    // Reshape to a meshgrid-style view: full length along axis i, size 1 elsewhere,
    // so the index tensors broadcast against each other.
    auto temp = std::vector<int64_t>(flip_dims_size, 1);
    temp[i] = indices[i].size(0);
    indices[i] = indices[i].view(IntList(temp));
    final_indices[flip_dims_v[i]] = indices[i];
  }

  // Advanced indexing performs the flip itself. Hoisted out of the branch below:
  // both paths need the same gather (previously duplicated in each branch).
  auto out_tensor = self.index(TensorList(final_indices));

  // If any two consecutive (sorted) flip dims are >= 2 apart, the indexed dims are
  // non-consecutive and advanced indexing moves them to the front of the result,
  // so we must permute the output back to the original dimension order.
  bool to_permute = false;
  for (int64_t i = 1; i < flip_dims_size; i++) {
    if (flip_dims_v[i] - flip_dims_v[i - 1] >= 2) {
      to_permute = true;
      break;
    }
  }
  if (!to_permute) {
    return out_tensor;
  }

  // Inverse permutation: the flipped dims come first (in sorted order), followed
  // by every non-flipped dim in ascending order.
  auto permute_order = std::vector<int64_t>(flip_dims_v);
  permute_order.reserve(total_dims);  // final size is known; avoid reallocations
  for (int64_t i = 0; i < total_dims; i++) {
    if (std::find(flip_dims_v.begin(), flip_dims_v.end(), i) == flip_dims_v.end()) {
      permute_order.emplace_back(i);
    }
  }
  return out_tensor.permute(IntList(permute_order));
}
// Rotate `self` by k * 90 degrees in the plane spanned by dims[0] and dims[1].
// A quarter turn is implemented as one flip composed with one transpose;
// a half turn is a flip over both dims; k == 0 (mod 4) returns a copy.
Tensor rot90(const Tensor& self, int64_t k, IntList dims) {
  const int64_t total_dims = self.dim();
  const int64_t total_rot_dims = dims.size();

  // Validate the rotation plane. Check order is significant: it determines
  // which error message a caller sees first.
  AT_CHECK(total_rot_dims == 2,
    "expected total rotation dims == 2, but got dims = ", total_rot_dims);
  AT_CHECK(total_dims >= 2,
    "expected total dims >= 2, but got total dims = ", total_dims);
  // dims may be negative; the abs-difference test catches pairs like (d, d - total_dims)
  // that wrap to the same dimension.
  AT_CHECK(dims[0] != dims[1] && std::abs(dims[0] - dims[1]) != total_dims,
    "expected rotation dims to be different, but got dim0 = ", dims[0],
    " and dim1 = ", dims[1]);
  AT_CHECK(dims[0] < total_dims && dims[0] >= -total_dims,
    "Rotation dim0 out of range, dim0 = ", dims[0]);
  AT_CHECK(dims[1] < total_dims && dims[1] >= -total_dims,
    "Rotation dim1 out of range, dim1 = ", dims[1]);

  // Normalize k into [0, 3]; C++ % keeps the sign of the dividend, so add 4 first.
  k = ((k % 4) + 4) % 4;

  if (k == 1) {
    // 90 degrees: flip the second dim, then swap the two dims.
    return self.flip({dims[1]}).transpose_(dims[0], dims[1]);
  }
  if (k == 2) {
    // 180 degrees: flip both dims, no transpose needed.
    return self.flip(dims);
  }
  if (k == 3) {
    // 270 degrees: flip the first dim, then swap the two dims.
    return self.flip({dims[0]}).transpose_(dims[0], dims[1]);
  }
  // k == 0: no rotation; return an independent copy.
  return self.clone();
}
}} // namespace at::native