forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
blob.h
71 lines (62 loc) · 2.13 KB
/
blob.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#ifndef CAFFE2_CORE_BLOB_H_
#define CAFFE2_CORE_BLOB_H_
#include <cstddef>
#include <sstream>
#include <typeinfo>
#include <type_traits>
#include <vector>
#include "caffe2/core/common.h"
#include <ATen/core/blob.h>
#include <ATen/core/typeid.h>
#include "caffe2/core/logging.h"
#include "caffe2/core/tensor.h"
namespace caffe2 {
inline bool BlobIsTensorType(const Blob& blob, DeviceType device_type) {
bool is_match = blob.meta().Match<Tensor>();
if (!is_match) {
return false;
}
const Tensor* tensor = &blob.Get<Tensor>();
return tensor && *tensor && tensor->GetDeviceType() == device_type;
}
/// Returns a Tensor stored in `blob` with sizes `dims` and the device/dtype
/// given by `options`, reusing the blob's existing Tensor when possible.
///
/// The existing Tensor is reused only when the blob holds a Tensor that:
///   - converts to true,
///   - is on `options.device()`, and
///   - has dtype `options.dtype()`.
/// In that case it is resized to `dims` if its sizes differ. In every other
/// case the blob's contents are replaced (via `Blob::Reset`) with a fresh
/// `caffe2::empty(dims, options)` tensor.
///
/// @param blob    Blob to inspect and possibly overwrite; must be non-null.
/// @param dims    Desired tensor sizes.
/// @param options Desired device and dtype of the tensor.
/// @return Pointer to the Tensor now held by `blob`.
inline Tensor*
BlobGetMutableTensor(Blob* blob, at::IntList dims, at::TensorOptions options) {
  if (blob->IsType<Tensor>()) {
    Tensor* tensor = blob->GetMutable<Tensor>();
    // Decide reuse before mutating anything: a tensor with the right device
    // but the wrong dtype is replaced below, so resizing it first would be
    // wasted work on an object that is about to be discarded.
    if (*tensor && tensor->GetDevice() == options.device() &&
        tensor->dtype() == options.dtype()) {
      if (tensor->sizes() != dims) {
        // Resize when the dims don't match.
        tensor->Resize(dims);
      }
      // Result intentionally discarded. NOTE(review): presumably this forces
      // storage to be (re)allocated for the current size/dtype after a
      // Resize — confirm against Tensor::raw_mutable_data.
      tensor->raw_mutable_data();
      return tensor;
    }
    // Fall through: the held tensor is not usable (false-valued, on another
    // device, or of another dtype), so a new one is created below.
  }
  VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
          << " dims: " << dims;
  // << " options: " << options; (operator<< for Options is in at:: now)
  return blob->Reset<Tensor>(new Tensor(caffe2::empty(dims, options)));
}
/// Returns the Tensor held by `blob` when that Tensor converts to true and
/// already has device type `device_type`. Otherwise replaces the blob's
/// contents with a newly constructed `Tensor(device_type)` and returns it.
///
/// @param blob        Blob to inspect and possibly overwrite; must be non-null.
/// @param device_type Device type the returned Tensor must have.
/// @return Pointer to the Tensor now held by `blob`.
inline Tensor* BlobGetMutableTensor(Blob* blob, DeviceType device_type) {
  const bool holds_tensor = blob->IsType<Tensor>();
  if (holds_tensor) {
    Tensor* existing = blob->GetMutable<Tensor>();
    const bool reusable =
        *existing && existing->GetDeviceType() == device_type;
    if (reusable) {
      return existing;
    }
  }
  // Reaching this point means the blob held something other than a Tensor,
  // or held a Tensor that was false-valued or on a different DeviceType.
  VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
          << " DeviceType:" << device_type;
  return blob->Reset<Tensor>(new Tensor(device_type));
}
} // namespace caffe2
#endif // CAFFE2_CORE_BLOB_H_