forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
THCStorage.cpp
65 lines (52 loc) · 1.68 KB
/
THCStorage.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include "THCStorage.hpp"
#include "THCGeneral.h"
#include "TH/THHalf.h"
#include <new>
#include "generic/THCStorage.cpp"
#include "THCGenerateAllTypes.h"
#include <ATen/core/intrusive_ptr.h>
// Resizes `self` so that it holds `size` elements on the current CUDA device.
//
// For size == 0 the data pointer is replaced by a null DataPtr tagged with the
// current device and numel is set to 0.
// For size > 0 a fresh buffer of size * itemsize bytes is obtained from the
// storage's own allocator. If the storage already has data, the first
// min(old numel, size) elements are copied into the new buffer with
// cudaMemcpyAsync on the current THC stream (after requesting peer-to-peer
// access between the current device and the storage's device). The new buffer
// then destructively replaces the old DataPtr, and numel is set to `size`.
//
// @param state THC state; supplies the current stream and p2p access setup.
// @param self  Storage to resize; must be resizable and have a non-null
//              allocator.
// @param size  New element count; must be >= 0, and size * itemsize must be
//              representable in size_t.
// Failures are reported via THArgCheck / THAssert / THError / THCudaCheck.
void THCStorage_resize(THCState *state, THCStorage *self, ptrdiff_t size)
{
  THArgCheck(size >= 0, 2, "invalid size");
  THAssert(self->allocator() != nullptr);
  int device;
  THCudaCheck(cudaGetDevice(&device));

  if (!self->resizable())
    THError("Trying to resize storage that is not resizable");

  size_t itemsize = self->itemsize();

  if (size == 0)
  {
    self->set_data_ptr(at::DataPtr(nullptr, at::Device(at::DeviceType::CUDA, device)));
    self->set_numel(0);
  }
  else
  {
    // Unsigned multiplication wraps silently. Without this check, a huge
    // `size` could yield a tiny allocation while numel is still set to `size`
    // below, leading to out-of-bounds accesses later.
    size_t nbytes = static_cast<size_t>(size) * itemsize;
    THArgCheck(itemsize == 0 || nbytes / itemsize == static_cast<size_t>(size),
               2, "size overflows the addressable byte count");

    at::DataPtr data = self->allocator()->allocate(nbytes);

    if (self->data_ptr()) {
      // Enable p2p access when the memcpy is across devices
      THCState_getPeerToPeerAccess(state, device, THCStorage_getDevice(state, self));

      // Preserve the overlapping prefix of the old contents. The copy is
      // asynchronous with respect to the host, on the current THC stream.
      THCudaCheck(cudaMemcpyAsync(data.get(),
                                  self->data(),
                                  THMin(self->numel(), size) * itemsize,
                                  cudaMemcpyDeviceToDevice,
                                  THCState_getCurrentStream(state)));
    }

    // Destructively overwrite data_ptr
    self->set_data_ptr(std::move(data));
    self->set_numel(size);
  }
}
// Returns the index of the device recorded on `storage`, converted to int.
// `state` is not used by this function.
int THCStorage_getDevice(THCState* state, const THCStorage* storage) {
  const auto location = storage->device();
  return location.index();
}
// Creates a new at::StorageImpl for elements of `data_type`, backed by
// state->cudaDeviceAllocator, and hands ownership to the caller as a raw
// pointer (the intrusive_ptr is released, so the caller owns the reference).
//
// NOTE(review): the literal arguments 0 and true are presumably the initial
// element count and the "resizable" flag of the at::StorageImpl constructor;
// confirm against the StorageImpl declaration.
THC_API THCStorage* THCStorage_new(
    THCState* state,
    caffe2::TypeMeta data_type) {
  auto owned = c10::make_intrusive<at::StorageImpl>(
      data_type,
      0,
      state->cudaDeviceAllocator,
      true);
  THStorage* const raw = owned.release();
  return raw;
}