forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
nested_int.cpp
105 lines (95 loc) · 3.24 KB
/
nested_int.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include <gtest/gtest.h>
#include <ATen/core/NestedIntSymNodeImpl.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>
#include <torch/torch.h>
#include <test/cpp/api/support.h>
TEST(NestedIntTest, Comparisons) {
  // Wraps a freshly allocated NestedIntSymNodeImpl(val, coeff) in a SymInt.
  // NOTE(review): the first argument presumably identifies the nested int and
  // the second is its coefficient — confirm against NestedIntSymNodeImpl.h.
  auto nested_int = [](auto val, auto coeff) {
    return c10::SymInt(
        c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(val, coeff)));
  };

  const auto j1 = nested_int(1, 1);
  const auto j1_again = nested_int(1, 1); // distinct node, same arguments as j1
  const auto j2 = nested_int(2, 1);       // differs from j1 in the first argument
  const auto three = c10::SymInt(3);      // SymInt built from an integer literal

  // Equality: nested ints built from identical arguments compare equal; a
  // nested int never equals one built from a different first argument, nor a
  // SymInt holding an integer constant (checked from both sides).
  ASSERT_TRUE(j1 == j1);
  ASSERT_TRUE(j1 == j1_again);
  ASSERT_FALSE(j1 != j1);
  ASSERT_FALSE(j1 != j1_again);
  ASSERT_FALSE(j1 == j2);
  ASSERT_TRUE(j1 != j2);
  ASSERT_FALSE(j1 == three);
  ASSERT_TRUE(j1 != three);
  ASSERT_FALSE(three == j1);
  ASSERT_TRUE(three != j1);

  // Ordering rules encoded by the sections below:
  //  * equal nested ints order like equal values;
  //  * ordering two different nested ints raises c10::Error;
  //  * against integer constants a nested int acts like an unknown value that
  //    is at least 2: comparisons decided by that fact succeed, while any
  //    comparison that would need to know whether it is >= 3 raises c10::Error.

  // ge
  ASSERT_TRUE(j1 >= j1);
  ASSERT_TRUE(j1 >= j1_again);
  ASSERT_TRUE(j1_again >= j1);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j1 >= j2), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j2 >= j1), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j2 >= 3), c10::Error);
  ASSERT_TRUE(j2 >= 2);
  ASSERT_TRUE(j2 >= 1);
  ASSERT_FALSE(1 >= j2);

  // lt
  ASSERT_FALSE(j1 < j1);
  ASSERT_FALSE(j1 < j1_again);
  ASSERT_FALSE(j1_again < j1);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j1 < j2), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j2 < j1), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(3 < j1), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(2 < j1), c10::Error);
  ASSERT_TRUE(1 < j1);

  // le
  ASSERT_TRUE(j1 <= j1);
  ASSERT_TRUE(j1_again <= j1);
  ASSERT_TRUE(j1 <= j1_again);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j1 <= j2), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j2 <= j1), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(3 <= j2), c10::Error);
  ASSERT_TRUE(2 <= j2);
  ASSERT_TRUE(1 <= j2);
  ASSERT_FALSE(j2 <= 1);

  // gt
  ASSERT_FALSE(j1 > j1);
  ASSERT_FALSE(j1_again > j1);
  ASSERT_FALSE(j1 > j1_again);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j1 > j2), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j2 > j1), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j1 > 3), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j1 > 2), c10::Error);
  ASSERT_TRUE(j1 > 1);
}
TEST(NestedIntTest, WithFactor) {
  // Wraps a freshly allocated NestedIntSymNodeImpl(val, coeff) in a SymInt.
  auto nested_int = [](auto val, auto coeff) {
    return c10::SymInt(
        c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(val, coeff)));
  };

  // Same first argument, different second argument. The `* 2 == ` check at
  // the end expects the second argument to scale the value multiplicatively.
  const auto j1_times5 = nested_int(1, 5);
  const auto j1_times10 = nested_int(1, 10);

  // eq: differing factors are not equal.
  ASSERT_FALSE(j1_times5 == j1_times10);
  // ge / le: ordering is expected to follow the factors (5 < 10).
  ASSERT_FALSE(j1_times5 >= j1_times10);
  ASSERT_TRUE(j1_times10 >= j1_times5);
  ASSERT_TRUE(j1_times5 <= j1_times10);
  ASSERT_FALSE(j1_times10 <= j1_times5);
  // ne
  ASSERT_TRUE(j1_times5 != j1_times10);
  // mul: multiplying by an integer constant scales the factor, from either side.
  ASSERT_TRUE(j1_times5 * 2 == j1_times10);
  ASSERT_TRUE(j1_times5 * 3 >= j1_times10);
  ASSERT_TRUE(j1_times5 * 2 == 2 * j1_times5);
}