// THCCachingHostAllocator.cpp
#include "THCCachingHostAllocator.h"
#include "THCStream.h"
#include <cuda_runtime_api.h>
#include <deque>
#include <memory>
#include <mutex>
#include <set>
#include <stdint.h>
#include <unordered_map>
#include <utility>
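
// Overview: a caching allocator for pinned (page-locked) host memory.
// cudaHostAlloc/cudaFreeHost are expensive, so freed blocks are kept in a
// pool and reused. A freed block only becomes reusable once every CUDA
// event recorded against it (one per stream that used it) has completed.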

namespace {

typedef std::shared_ptr<THCStream> THCStreamPtr;

struct BlockSize
{
  size_t size; // allocation size
  void* ptr;   // host memory pointer

  BlockSize(size_t size, void* ptr=NULL) : size(size), ptr(ptr) {}
};

struct Block : public BlockSize
{
  bool allocated;   // true if the block is currently allocated
  int event_count;  // number of outstanding cuda events
  std::set<THCStreamPtr> streams;

  Block(size_t size, void* ptr, bool allocated) :
      BlockSize(size, ptr), allocated(allocated), event_count(0), streams() {}
};

static bool BlockComparator(const BlockSize& a, const BlockSize& b)
{
  // sort by size, break ties with pointer
  if (a.size != b.size) {
    return a.size < b.size;
  }
  return (uintptr_t)a.ptr < (uintptr_t)b.ptr;
}
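
// Block lifecycle: malloc() hands out a block with allocated=true; free()
// flips it to false and records one event per associated stream; once
// event_count drops back to zero the block is placed in 'available', which
// is ordered by (size, ptr) so lower_bound() yields a best-fit search.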

struct HostAllocator
{
  typedef bool (*Comparison)(const BlockSize&, const BlockSize&);

  // lock around all operations
  std::mutex mutex;

  // blocks by pointer
  std::unordered_map<void*, Block> blocks;

  // pointers that are ready to be allocated (event_count=0)
  std::set<BlockSize, Comparison> available;

  // outstanding cuda events
  std::deque<std::pair<cudaEvent_t, void*>> cuda_events;

  HostAllocator() : available(BlockComparator) {}

  cudaError_t malloc(void** ptr, size_t size)
  {
    std::lock_guard<std::mutex> lock(mutex);

    // process outstanding cuda events which may have occurred
    cudaError_t err = processEvents();
    if (err != cudaSuccess) {
      return err;
    }

    // search for the smallest block which can hold this allocation
    BlockSize search_key(size);
    auto it = available.lower_bound(search_key);
    if (it != available.end()) {
      Block& block = blocks.at(it->ptr);
      THAssert(!block.allocated && block.event_count == 0);
      block.allocated = true;
      *ptr = block.ptr;
      available.erase(it);
      return cudaSuccess;
    }

    // note that cudaHostAlloc may not touch pointer if size is 0
    *ptr = 0;

    // allocate a new block if no cached allocation is found
    err = cudaHostAlloc(ptr, size, cudaHostAllocDefault);
    if (err != cudaSuccess) {
      return err;
    }

    blocks.insert({*ptr, Block(size, *ptr, true)});
    return cudaSuccess;
  }

  cudaError_t free(void* ptr)
  {
    std::lock_guard<std::mutex> lock(mutex);

    if (!ptr) {
      return cudaSuccess;
    }

    // process outstanding cuda events which may have occurred
    cudaError_t err = processEvents();
    if (err != cudaSuccess) {
      return err;
    }

    auto it = blocks.find(ptr);
    THAssert(it != blocks.end());

    Block& block = it->second;
    THAssert(block.allocated);

    // free (on valid memory) shouldn't fail, so mark unallocated before
    // we process the streams.
    block.allocated = false;

    // insert CUDA events for each stream on which this block was used. This
    // prevents the block from being re-used until all events have completed.
    err = insertEvents(block);
    if (err != cudaSuccess) {
      return err;
    }

    if (block.event_count == 0) {
      // the block can be re-used if there are no outstanding cuda events
      available.insert(block);
    }
    return cudaSuccess;
  }

  cudaError_t recordEvent(void* ptr, THCStream *stream)
  {
    std::lock_guard<std::mutex> lock(mutex);

    auto it = blocks.find(ptr);
    if (it == blocks.end()) {
      // ignore events for untracked pointers
      return cudaSuccess;
    }

    Block& block = it->second;
    THAssert(block.allocated);

    // take a reference to the stream; it is released via THCStream_free
    // when the shared_ptr is destroyed
    THCStreamPtr stream_ptr(stream, &THCStream_free);
    THCStream_retain(stream);

    block.streams.insert(std::move(stream_ptr));
    return cudaSuccess;
  }

  cudaError_t processEvents()
  {
    // Process outstanding cudaEvents. Events that are completed are removed
    // from the queue, and the 'event_count' for the corresponding allocation
    // is decremented. Stops at the first event which has not been completed.
    // Since events on different devices or streams may occur out of order,
    // the processing of some events may be delayed.
    while (!cuda_events.empty()) {
      auto& e = cuda_events.front();
      cudaEvent_t event = e.first;

      cudaError_t err = cudaEventQuery(event);
      if (err == cudaErrorNotReady) {
        break;
      } else if (err != cudaSuccess) {
        return err;
      }
      err = cudaEventDestroy(event);
      if (err != cudaSuccess) {
        return err;
      }

      Block& block = blocks.at(e.second);
      block.event_count--;
      if (block.event_count == 0 && !block.allocated) {
        available.insert(block);
      }
      cuda_events.pop_front();
    }
    return cudaSuccess;
  }

  void emptyCache()
  {
    std::lock_guard<std::mutex> lock(mutex);

    // remove events for freed blocks
    for (auto it = cuda_events.begin(); it != cuda_events.end(); ++it) {
      cudaEvent_t event = it->first;
      Block& block = blocks.at(it->second);
      if (!block.allocated) {
        THCudaCheckWarn(cudaEventDestroy(event));
        block.event_count--;
      }
    }

    // all cuda_events have been processed
    cuda_events.clear();

    // clear list of available blocks
    available.clear();

    // free and erase non-allocated blocks
    for (auto it = blocks.begin(); it != blocks.end();) {
      Block& block = it->second;
      if (!block.allocated) {
        THCudaCheckWarn(cudaFreeHost(block.ptr));
        it = blocks.erase(it);
      } else {
        ++it;
      }
    }
  }

  cudaError_t insertEvents(Block& block)
  {
    cudaError_t err;

    int prev_device;
    err = cudaGetDevice(&prev_device);
    if (err != cudaSuccess) return err;

    std::set<THCStreamPtr> streams(std::move(block.streams));
    for (auto it = streams.begin(); it != streams.end(); ++it) {
      auto& stream = *it;

      err = cudaSetDevice(THCStream_device(stream.get()));
      if (err != cudaSuccess) break;

      cudaEvent_t event;
      err = cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
      if (err != cudaSuccess) break;

      err = cudaEventRecord(event, THCStream_stream(stream.get()));
      if (err != cudaSuccess) break;

      block.event_count++;
      cuda_events.emplace_back(event, block.ptr);
    }

    cudaSetDevice(prev_device);
    return err;
  }
};

}  // namespace

static HostAllocator allocator;

cudaError_t THCCachingHostAllocator_recordEvent(void *ptr, THCStream *stream)
{
  return allocator.recordEvent(ptr, stream);
}
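
// A typical (hypothetical) call sequence: after launching an async copy out
// of a pinned buffer, record the stream so the buffer is not handed back out
// while the copy may still be in flight. 'src', 'dst', 'nbytes', and 'stream'
// here are illustrative names, not part of this file:
//
//   cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyHostToDevice,
//                   THCStream_stream(stream));
//   THCCachingHostAllocator_recordEvent(src, stream);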

void THCCachingHostAllocator_emptyCache()
{
  allocator.emptyCache();
}

static void THCCachingHostDeleter(void* ptr) {
  allocator.free(ptr);
}

struct THCCachingHostAllocator final : public at::Allocator {
  at::DataPtr allocate(size_t size) const override {
    THAssert(size >= 0);
    void *ptr;
    THCudaCheck(allocator.malloc(&ptr, size));
    return {ptr, ptr, &THCCachingHostDeleter, at::DeviceType::CPU};
  }
  at::DeleterFnPtr raw_deleter() const override {
    return &THCCachingHostDeleter;
  }
};

static THCCachingHostAllocator thc_caching_host_allocator;

at::Allocator* getTHCCachingHostAllocator() {
  return &thc_caching_host_allocator;
}
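
// A minimal usage sketch (a hypothetical caller, not part of this file).
// The at::DataPtr returned by allocate() releases the buffer back into the
// cache via THCCachingHostDeleter when it goes out of scope:
//
//   at::Allocator* alloc = getTHCCachingHostAllocator();
//   at::DataPtr ptr = alloc->allocate(1024);  // pinned host memory
//   // ... use ptr.get() as the source/target of async copies, calling
//   // THCCachingHostAllocator_recordEvent(ptr.get(), stream) after each ...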