forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Allocator.cpp
36 lines (27 loc) · 930 Bytes
/
Allocator.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <ATen/core/Allocator.h>
namespace at {
static void deleteInefficientStdFunctionContext(void* ptr) {
delete static_cast<InefficientStdFunctionContext*>(ptr);
}
at::DataPtr InefficientStdFunctionContext::makeDataPtr(
void* ptr,
const std::function<void(void*)>& deleter,
Device device) {
return {ptr,
new InefficientStdFunctionContext({ptr, deleter}),
&deleteInefficientStdFunctionContext,
device};
}
} // namespace at
namespace caffe2 {

// Registry of allocators, one slot per device type, indexed by
// static_cast<int>(DeviceType). The array has static storage duration, so
// every slot starts out as nullptr ("not set").
CAFFE2_API at::Allocator* allocator_array[static_cast<int>(
    at::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];

namespace {

// Number of slots in allocator_array.
constexpr int kNumAllocatorSlots =
    static_cast<int>(at::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);

// Maps a device type to its allocator_array index. It fails loudly instead of
// letting callers read or write outside the array when `t` holds a value
// outside [0, kNumAllocatorSlots), e.g. one cast from an integer. The integer
// is reported rather than streaming `t`, since an out-of-range enum value may
// not have a printable name.
int allocatorIndex(at::DeviceType t) {
  const int idx = static_cast<int>(t);
  AT_ASSERTM(
      idx >= 0 && idx < kNumAllocatorSlots,
      "Device type index ",
      idx,
      " is out of range [0, ",
      kNumAllocatorSlots,
      ") for the allocator registry.");
  return idx;
}

} // namespace

// Registers `alloc` as the allocator for device type `t`, overwriting any
// previous registration. Passing nullptr clears the slot. No synchronization
// is performed here, so concurrent Set/Get calls on the same slot would race.
void SetAllocator(at::DeviceType t, at::Allocator* alloc) {
  allocator_array[allocatorIndex(t)] = alloc;
}

// Returns the allocator registered for device type `t`.
// Raises (via AT_ASSERTM) if `t` is out of range or no allocator has been
// registered for it.
at::Allocator* GetAllocator(const at::DeviceType& t) {
  auto* alloc = allocator_array[allocatorIndex(t)];
  AT_ASSERTM(alloc, "Allocator for ", t, " is not set.");
  return alloc;
}

} // namespace caffe2