patched_torchdistX_9c1b9f.patch
349 lines (321 loc) · 11.7 KB
diff --git a/setup.py b/setup.py
index a4102b5..b468cb1 100644
--- a/setup.py
+++ b/setup.py
@@ -173,7 +173,6 @@ def main() -> None:
zip_safe=False,
# Since PyTorch does not offer ABI compatibility we have to make sure
# that we use the same version that was used at build time.
- install_requires=[f"torch=={torch.__version__}"],
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
diff --git a/src/cc/torchdistx/deferred_init.cc b/src/cc/torchdistx/deferred_init.cc
index 961b638..0a29d9f 100644
--- a/src/cc/torchdistx/deferred_init.cc
+++ b/src/cc/torchdistx/deferred_init.cc
@@ -173,6 +173,7 @@ class Op {
}
void materialize();
+ void materializeWithShape(c10::IntArrayRef shape, const c10::optional<c10::Device> device);
std::size_t num_outputs() const noexcept {
return num_outputs_;
@@ -220,7 +221,6 @@ Op Op::fromOperatorHandle(const OperatorHandle& handle, Stack s) {
};
const FunctionSchema& shm = handle.schema();
-
return Op{shm.name(), std::move(fn), shm.arguments().size(), shm.returns().size(), std::move(s)};
}
@@ -271,6 +271,44 @@ void Op::materialize() {
materialized_ = true;
}
+void Op::materializeWithShape(c10::IntArrayRef shape, const c10::optional<c10::Device> device) {
+ if (materialized_) {
+ return;
+ }
+
+ {
+ ThreadLocalStateGuard state_guard{*tls_};
+
+ auto replace_first_shape = [&](c10::IntArrayRef sp){
+ IValue local_shape(sp);
+ stack_[0] = local_shape;
+ };
+
+ std::vector<std::string> op_white_list{"aten::randn", "aten::rand", "aten::empty", "aten::ones", "aten::zeros", "aten::full" };
+
+ if (std::find(op_white_list.begin(),op_white_list.end(), name()) != op_white_list.end()){
+ // If this is one of the whitelisted factory ops, replace its recorded shape argument.
+ replace_first_shape(shape);
+ }
+
+ if(device.has_value()){ // set target device
+ for (size_t i = 0 ; i < stack_.size(); i++){
+ if(stack_[i].isDevice()){
+ stack_[i] = IValue(device.value());
+ }
+ }
+ }
+
+ fn_(stack_);
+ }
+
+ fn_ = nullptr;
+
+ tls_ = nullopt;
+
+ materialized_ = true;
+}
+
const Tensor& Op::getOutput(std::size_t idx) const noexcept {
const Tensor* opt_out = nullptr;
@@ -343,6 +381,8 @@ class OpNode {
// Materializes the operation held by this node along with all the operations
// in its recorded call stack.
void materialize();
+ // Same as materialize(), but with the recorded shape (and optionally device) overridden.
+ void materializeWithShape(c10::IntArrayRef shape, c10::optional<c10::Device> device);
private:
void buildCallStack();
@@ -527,6 +567,30 @@ void OpNode::materialize() {
call_stack_.clear();
}
+void OpNode::materializeWithShape(c10::IntArrayRef shape, const c10::optional<c10::Device> device) {
+ // Do not try to shortcut this function by checking if the node is already
+ // materialized. A later in-place operation can still change the output of
+ // this node.
+
+ buildCallStack();
+
+ for (OpNode* node : call_stack_) {
+ if (node->op_.materialized()) {
+ continue;
+ }
+
+ node->materializeArguments();
+
+ node->op_.materializeWithShape(shape, device);
+
+ // Make sure that we deallocate parts of the operation graph that are not
+ // needed anymore.
+ node->detachDependencies();
+ }
+
+ call_stack_.clear();
+}
+
void OpNode::buildCallStack() {
OpNode* last_node = getLastInPlaceOpNode();
@@ -728,6 +792,24 @@ Tensor materialize(const Tensor& fake) {
return out;
}
+Tensor materialize_with_shape(const Tensor& fake, c10::IntArrayRef shape, const c10::optional<c10::Device> device) {
+ TensorRecord& record = getTensorRecord(fake);
+
+ const OpOutputDescriptor& output_desc = record.output_descriptor();
+
+ output_desc.node()->materializeWithShape(shape, device);
+
+ Tensor out = output_desc.node()->op().getOutput(output_desc.output_index());
+
+ // Unfortunately there is no way for us to track calls to `requires_grad_()`,
+ // so instead we explicitly set `requires_grad` after materialization.
+ if (fake.is_leaf() && fake.requires_grad()) {
+ out.set_requires_grad(true);
+ }
+
+ return out;
+}
+
// The catch-all handler for the `DeferredInit` dispatch key.
class DeferredInitHandler {
public:
@@ -1032,6 +1114,12 @@ class ProxyVariableHooks : public VariableHooksInterface {
inner_->requires_grad_(self, value);
}
+ void basic_autograd_not_implemented_fallback(const c10::OperatorHandle& op,
+ c10::DispatchKeySet dispatch_keys,
+ torch::jit::Stack* stack) const override {
+ inner_->basic_autograd_not_implemented_fallback(op, dispatch_keys, stack);
+ }
+
VariableHooksInterface* inner() noexcept {
return inner_;
}
@@ -1164,6 +1252,7 @@ bool canMaterialize(const Tensor& tensor) noexcept {
return isFake(tensor) && unsafeAsFake(tensor).hasData(DispatchKey::DeferredInit);
}
+
Tensor materializeTensor(const Tensor& tensor) {
if (canMaterialize(tensor)) {
return detail::materialize(tensor);
@@ -1172,4 +1261,24 @@ Tensor materializeTensor(const Tensor& tensor) {
}
}
+Tensor materializeTensorWithLocalShape(const at::Tensor& tensor, c10::IntArrayRef shape, const c10::optional<c10::Device> device){
+ if (canMaterialize(tensor)) {
+ return detail::materialize_with_shape(tensor, shape, device);
+ } else {
+ return tensor;
+ }
+}
+
+bool isGenByRandomOp(const Tensor& tensor) noexcept{
+ if (canMaterialize(tensor)) {
+ detail::TensorRecord& record = detail::getTensorRecord(tensor);
+ const detail::OpOutputDescriptor& output_desc = record.output_descriptor();
+ auto name = output_desc.node()->op().name();
+ std::vector<std::string> op_white_list{"aten::randn", "aten::rand"};
+ return std::find(op_white_list.begin(),op_white_list.end(), name) != op_white_list.end();
+ }else{
+ return false;
+ }
+}
+
} // namespace torchdistx
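
Note (not part of the patch): the hunks above replay the recorded op graph while substituting the shape argument of whitelisted factory ops and retargeting any recorded device arguments. A rough Python analogue, for orientation only; the function name and recorded-call representation below are invented for exposition:

import torch

# Factory ops whose first argument is the shape (mirrors the whitelist in the patch).
FACTORY_OPS = {"aten::randn", "aten::rand", "aten::empty",
               "aten::ones", "aten::zeros", "aten::full"}

def replay_with_shape(name, args, kwargs, shape, device=None):
    """Replay a recorded factory call with an overridden shape and, optionally, device."""
    if name in FACTORY_OPS:
        args = (tuple(shape),) + tuple(args[1:])  # replace the first (shape) argument
    if device is not None:
        kwargs = {**kwargs, "device": device}     # retarget the recorded device
    fn = getattr(torch, name.split("::")[1])      # e.g. "aten::randn" -> torch.randn
    return fn(*args, **kwargs)

# A tensor recorded as torch.randn(1024, 1024) can be materialized directly as a
# [512, 1024] local shard on CPU instead.
shard = replay_with_shape("aten::randn", ((1024, 1024),), {}, (512, 1024),
                          device=torch.device("cpu"))
print(shard.shape)  # torch.Size([512, 1024])
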
diff --git a/src/cc/torchdistx/deferred_init.h b/src/cc/torchdistx/deferred_init.h
index afda4cb..4c654cf 100644
--- a/src/cc/torchdistx/deferred_init.h
+++ b/src/cc/torchdistx/deferred_init.h
@@ -8,6 +8,8 @@
#include <c10/core/DispatchKey.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
+#include <c10/core/SymIntArrayRef.h>
+#include <c10/core/Device.h>
#include "macros.h"
@@ -27,9 +29,10 @@ TDX_API void leaveDeferredInit() noexcept;
// Indicates whether `tensor` has been constructed in a deferred-init context.
TDX_API bool canMaterialize(const at::Tensor& tensor) noexcept;
-
+TDX_API bool isGenByRandomOp(const at::Tensor& tensor) noexcept;
// Materializes `tensor`.
TDX_API at::Tensor materializeTensor(const at::Tensor& tensor);
+TDX_API at::Tensor materializeTensorWithLocalShape(const at::Tensor& tensor, c10::IntArrayRef shape, const c10::optional<c10::Device> device = {});
// Temporarily disables deferred-init.
class TDX_API NoDeferredInit {
diff --git a/src/python/torchdistx/_C.pyi b/src/python/torchdistx/_C.pyi
index 82c8421..9e4462a 100644
--- a/src/python/torchdistx/_C.pyi
+++ b/src/python/torchdistx/_C.pyi
@@ -5,12 +5,17 @@
# LICENSE file in the root directory of this source tree.
import torch
+from torch.types import _int, SymInt, _device
+from collections.abc import Sequence
+from typing import Union, Optional
def enter_deferred_init() -> None: ...
def leave_deferred_init() -> None: ...
def enter_fake_mode(fake_mode: bool) -> None: ...
def leave_fake_mode() -> None: ...
def is_fake(tensor: torch.Tensor) -> bool: ...
+def is_gen_by_random_op(tensor: torch.Tensor) -> bool: ...
def can_materialize(tensor: torch.Tensor) -> bool: ...
def materialize_tensor(tensor: torch.Tensor) -> torch.Tensor: ...
+def materialize_tensor_with_local_shape(tensor: torch.Tensor, shape: Sequence[Union[_int, SymInt]], device: Optional[Union[_device, str]] = None) -> torch.Tensor: ...
def meta_like(fake: torch.Tensor) -> torch.Tensor: ...
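
Note (not part of the patch): a hedged usage sketch of the new stubs, assuming the usual torchdistx deferred_init workflow; the row-wise sharding arithmetic and the choice of module are illustrative only:

import torch
from torchdistx import deferred_init
from torchdistx._C import (
    can_materialize,
    is_gen_by_random_op,
    materialize_tensor_with_local_shape,
)

# Build the module without allocating real storage.
model = deferred_init.deferred_init(torch.nn.Linear, 1024, 1024)
weight = model.weight

if can_materialize(weight):
    # Hypothetical row-wise shard: materialize only half of the rows on this rank.
    local_shape = (weight.shape[0] // 2, weight.shape[1])
    local = materialize_tensor_with_local_shape(weight, local_shape,
                                                torch.device("cpu"))
    # Tensors recorded from aten::randn / aten::rand are flagged so callers can
    # decide how to handle per-shard random initialization.
    print(is_gen_by_random_op(weight), local.shape)
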
diff --git a/src/python/torchdistx/_C/deferred_init.cc b/src/python/torchdistx/_C/deferred_init.cc
index 7eb73b8..d810544 100644
--- a/src/python/torchdistx/_C/deferred_init.cc
+++ b/src/python/torchdistx/_C/deferred_init.cc
@@ -8,6 +8,7 @@
#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
+#include <c10/core/SymIntArrayRef.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/pybind.h>
@@ -94,14 +95,58 @@ py::object materializeVariable(const py::object& var) {
return makeVariable(Py_TYPE(naked_var), std::move(materialized_data));
}
+
+// Materializing a tensor in Python requires an extra step. We need to ensure
+// that the materialized tensor has the same Python class (e.g. `Variable` or
+// `Parameter`) as the original tensor.
+// In the DTensor case we additionally replace the recorded shape with the parallelized (local) shape.
+py::object materializeVariableWithLocalShape(const py::object& var, const py::object &shape, const c10::optional<c10::Device> device) {
+ PyObject* naked_var = var.ptr();
+ auto c_shape = shape.cast<std::vector<int64_t>>();
+
+ if (!THPVariable_Check(naked_var)) {
+ throw TypeError{"`var` has to be a `Variable`, but got `%s`.", Py_TYPE(naked_var)->tp_name};
+ }
+
+ const Tensor& data = THPVariable_Unpack(naked_var);
+
+ auto materialize = [=](const Tensor& tensor, c10::IntArrayRef sp) {
+ py::gil_scoped_release guard{};
+
+ return materializeTensorWithLocalShape(tensor, sp, device);
+ };
+
+ Tensor materialized_data = materialize(data, at::IntArrayRef(c_shape));
+
+ // Check if we have really materialized `data`. Materializing a regular tensor
+ // is a no-op, so we can simply return.
+ if (materialized_data.is_same(data)) {
+ return var;
+ }
+
+ // We might have already materialized `data`. Make sure that we preserve its
+ // identity on the Python side and avoid creating a new Python tensor.
+ c10::optional<PyObject*> opt_materialized_var =
+ materialized_data.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());
+ if (opt_materialized_var.has_value()) {
+ return py::reinterpret_borrow<py::object>(*opt_materialized_var);
+ }
+
+ // Otherwise ensure that our materialized tensor has the same Python class as
+ // the original tensor.
+ return makeVariable(Py_TYPE(naked_var), std::move(materialized_data));
+}
+
+
} // namespace
void initDeferredInitFunctions(py::module& m) {
m.def("enter_deferred_init", enterDeferredInit);
m.def("leave_deferred_init", leaveDeferredInit);
-
m.def("can_materialize", canMaterialize);
+ m.def("is_gen_by_random_op", isGenByRandomOp);
m.def("materialize_tensor", materializeVariable);
+ m.def("materialize_tensor_with_local_shape", materializeVariableWithLocalShape);
}
} // namespace torchdistx::python
diff --git a/src/python/torchdistx/_C/fake.cc b/src/python/torchdistx/_C/fake.cc
index 70e0e3c..538a3b5 100644
--- a/src/python/torchdistx/_C/fake.cc
+++ b/src/python/torchdistx/_C/fake.cc
@@ -8,7 +8,12 @@
#include <ATen/Context.h>
#include <ATen/Tensor.h>
+#include <torch/torch.h>
+#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 3)
+#include <torch/csrc/utils/device_lazy_init.h>
+#else
#include <torch/csrc/utils/cuda_lazy_init.h>
+#endif
#include <torch/csrc/utils/pybind.h>
#include <torchdistx/fake.h>
@@ -22,7 +27,12 @@ void pyEnterFakeMode(bool fake_cuda) {
// subsystem which would fail and prevent us from instantiating CUDA devices.
if (fake_cuda) {
if (!at::hasCUDA()) {
+
+#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 3)
+ torch::utils::set_requires_device_init(at::kCUDA, false);
+#else
torch::utils::set_requires_cuda_init(false);
+#endif
}
}
}
@@ -31,7 +41,11 @@ void pyLeaveFakeMode() {
leaveFakeMode();
if (!isFakeModeActive() && !at::hasCUDA()) {
+#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 3)
+ torch::utils::set_requires_device_init(at::kCUDA, true);
+#else
torch::utils::set_requires_cuda_init(true);
+#endif
}
}
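
Note (not part of the patch): the fake.cc changes only switch to the renamed lazy-init helper on torch >= 2.3; the Python-visible behavior is unchanged. A minimal sketch of what fake CUDA mode enables on a host without a GPU, assuming the standard torchdistx fake-mode API:

import torch
from torchdistx import fake

# With fake_cuda=True on a CUDA-less host, CUDA lazy initialization is suppressed
# (set_requires_device_init on torch >= 2.3, set_requires_cuda_init otherwise),
# so "cuda" tensors can be constructed as fake tensors without a GPU.
with fake.fake_mode(fake_cuda=True):
    t = torch.ones(4, 4, device="cuda")

print(fake.is_fake(t))  # True
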