forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
db.h
312 lines (279 loc) · 8.52 KB
/
db.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
#ifndef CAFFE2_CORE_DB_H_
#define CAFFE2_CORE_DB_H_
#include <mutex>
#include "c10/util/Registry.h"
#include "caffe2/core/blob_serialization.h"
#include "caffe2/proto/caffe2_pb.h"
namespace caffe2 {
namespace db {
/**
 * How a database is opened: for reading, for writing, or as a newly created
 * database.
 */
enum Mode {
  READ, // open for reading
  WRITE, // open for writing
  NEW, // create a new database
};
/**
 * Abstract read-side cursor that walks the key/value pairs of a database.
 */
class CAFFE2_API Cursor {
 public:
  Cursor() = default;
  virtual ~Cursor() = default;

  /**
   * Positions the cursor at `key`; when `key` is not present, positions it at
   * the immediately following key instead. Seeking is an optional capability:
   * unless a db overrides SupportsSeek(), it reports false, which tells
   * callers that this cursor cannot seek.
   */
  virtual void Seek(const string& key) = 0;
  virtual bool SupportsSeek() {
    return false;
  }

  /**
   * Positions the cursor at the first key of the database.
   */
  virtual void SeekToFirst() = 0;

  /**
   * Advances the cursor to the next location in the database.
   */
  virtual void Next() = 0;

  /**
   * Returns the key at the current location.
   */
  virtual string key() = 0;

  /**
   * Returns the value at the current location.
   */
  virtual string value() = 0;

  /**
   * Reports whether the current location is valid. For instance, once the
   * cursor has moved past the end of the database this returns false.
   */
  virtual bool Valid() = 0;

  C10_DISABLE_COPY_AND_ASSIGN(Cursor);
};
/**
 * Abstract write-side handle representing the database transaction that is
 * currently in progress.
 */
class CAFFE2_API Transaction {
 public:
  Transaction() = default;
  virtual ~Transaction() = default;

  /**
   * Stores the given key/value pair in the database.
   */
  virtual void Put(const string& key, const string& value) = 0;

  /**
   * Commits the writes made so far.
   */
  virtual void Commit() = 0;

  C10_DISABLE_COPY_AND_ASSIGN(Transaction);
};
/**
 * Abstract interface to a database of key-value pairs.
 */
class CAFFE2_API DB {
 public:
  // The base class records only the mode; the source argument is accepted so
  // that every DB shares the (source, mode) constructor signature, but it is
  // not used here.
  DB(const string& /*source*/, Mode mode) : mode_(mode) {}
  virtual ~DB() = default;

  /**
   * Closes the database.
   */
  virtual void Close() = 0;

  /**
   * Creates a cursor for reading the database. Ownership of the cursor is
   * handed to the caller.
   */
  virtual std::unique_ptr<Cursor> NewCursor() = 0;

  /**
   * Creates a transaction for writing data to the database. Ownership of the
   * transaction is handed to the caller.
   */
  virtual std::unique_ptr<Transaction> NewTransaction() = 0;

 protected:
  // The mode this database was constructed with.
  Mode mode_;

  C10_DISABLE_COPY_AND_ASSIGN(DB);
};
// Database classes are registered by their names so we can do optional
// dependencies.
//
// Declares the Caffe2DBRegistry registry of DB implementations; registered
// classes are created with a (const string&, Mode) argument list. CreateDB()
// looks implementations up in this registry by database type name.
C10_DECLARE_REGISTRY(Caffe2DBRegistry, DB, const string&, Mode);
// Registers a DB implementation in Caffe2DBRegistry under `name`; the
// remaining arguments are forwarded unchanged to C10_REGISTER_CLASS.
#define REGISTER_CAFFE2_DB(name, ...) \
C10_REGISTER_CLASS(Caffe2DBRegistry, name, __VA_ARGS__)
/**
 * Creates a database object for the given database type, source and mode, and
 * hands ownership of it to the caller. When the database type is not
 * supported the result is a nullptr, so the caller must check the returned
 * pointer before using it.
 */
inline unique_ptr<DB> CreateDB(
    const string& db_type,
    const string& source,
    Mode mode) {
  auto db = Caffe2DBRegistry()->Create(db_type, source, mode);
  // Log at verbose level whether the registry knew about this db type.
  VLOG(1) << (db ? "found db " : "not found db ") << db_type;
  return db;
}
/**
* Returns whether or not a database exists given the database type and path.
*/
inline bool DBExists(const string& db_type, const string& full_db_name) {
// Warning! We assume that creating a DB throws an exception if the DB
// does not exist. If the DB constructor does not follow this design
// pattern,
// the returned output (the existence tensor) can be wrong.
try {
std::unique_ptr<DB> db(
caffe2::db::CreateDB(db_type, full_db_name, caffe2::db::READ));
return true;
} catch (...) {
return false;
}
}
/**
* A reader wrapper for DB that also allows us to serialize it.
*/
class CAFFE2_API DBReader {
public:
friend class DBReaderSerializer;
DBReader() {}
DBReader(
const string& db_type,
const string& source,
const int32_t num_shards = 1,
const int32_t shard_id = 0) {
Open(db_type, source, num_shards, shard_id);
}
explicit DBReader(const DBReaderProto& proto) {
Open(proto.db_type(), proto.source());
if (proto.has_key()) {
CAFFE_ENFORCE(cursor_->SupportsSeek(),
"Encountering a proto that needs seeking but the db type "
"does not support it.");
cursor_->Seek(proto.key());
}
num_shards_ = 1;
shard_id_ = 0;
}
explicit DBReader(std::unique_ptr<DB> db)
: db_type_("<memory-type>"),
source_("<memory-source>"),
db_(std::move(db)) {
CAFFE_ENFORCE(db_.get(), "Passed null db");
cursor_ = db_->NewCursor();
}
void Open(
const string& db_type,
const string& source,
const int32_t num_shards = 1,
const int32_t shard_id = 0) {
// Note(jiayq): resetting is needed when we re-open e.g. leveldb where no
// concurrent access is allowed.
cursor_.reset();
db_.reset();
db_type_ = db_type;
source_ = source;
db_ = CreateDB(db_type_, source_, READ);
CAFFE_ENFORCE(db_, "Cannot open db: ", source_, " of type ", db_type_);
InitializeCursor(num_shards, shard_id);
}
void Open(
unique_ptr<DB>&& db,
const int32_t num_shards = 1,
const int32_t shard_id = 0) {
cursor_.reset();
db_.reset();
db_ = std::move(db);
CAFFE_ENFORCE(db_.get(), "Passed null db");
InitializeCursor(num_shards, shard_id);
}
public:
/**
* Read a set of key and value from the db and move to next. Thread safe.
*
* The string objects key and value must be created by the caller and
* explicitly passed in to this function. This saves one additional object
* copy.
*
* If the cursor reaches its end, the reader will go back to the head of
* the db. This function can be used to enable multiple input ops to read
* the same db.
*
* Note(jiayq): we loosen the definition of a const function here a little
* bit: the state of the cursor is actually changed. However, this allows
* us to pass in a DBReader to an Operator without the need of a duplicated
* output blob.
*/
void Read(string* key, string* value) const {
CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized.");
std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
*key = cursor_->key();
*value = cursor_->value();
// In sharded mode, each read skips num_shards_ records
for (uint32_t s = 0; s < num_shards_; s++) {
cursor_->Next();
if (!cursor_->Valid()) {
MoveToBeginning();
break;
}
}
}
/**
* @brief Seeks to the first key. Thread safe.
*/
void SeekToFirst() const {
CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized.");
std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
MoveToBeginning();
}
/**
* Returns the underlying cursor of the db reader.
*
* Note that if you directly use the cursor, the read will not be thread
* safe, because there is no mechanism to stop multiple threads from
* accessing the same cursor. You should consider using Read() explicitly.
*/
inline Cursor* cursor() const {
VLOG(1) << "Usually for a DBReader you should use Read() to be "
"thread safe. Consider refactoring your code.";
return cursor_.get();
}
private:
void InitializeCursor(const int32_t num_shards, const int32_t shard_id) {
CAFFE_ENFORCE(num_shards >= 1);
CAFFE_ENFORCE(shard_id >= 0);
CAFFE_ENFORCE(shard_id < num_shards);
num_shards_ = num_shards;
shard_id_ = shard_id;
cursor_ = db_->NewCursor();
SeekToFirst();
}
void MoveToBeginning() const {
cursor_->SeekToFirst();
for (uint32_t s = 0; s < shard_id_; s++) {
cursor_->Next();
CAFFE_ENFORCE(
cursor_->Valid(), "Db has fewer rows than shard id: ", s, shard_id_);
}
}
string db_type_;
string source_;
unique_ptr<DB> db_;
unique_ptr<Cursor> cursor_;
mutable std::mutex reader_mutex_;
uint32_t num_shards_;
uint32_t shard_id_;
C10_DISABLE_COPY_AND_ASSIGN(DBReader);
};
/**
* Blob serializer for DBReader objects. DBReader declares this class as a
* friend. Only the declaration lives in this header; the implementation is
* defined elsewhere.
*/
class CAFFE2_API DBReaderSerializer : public BlobSerializerBase {
public:
/**
* Serializes a DBReader. Note that this blob has to contain DBReader,
* otherwise this function produces a fatal error.
*
* NOTE(review): pointer/typeMeta presumably describe the blob's content,
* name its blob name, and acceptor the sink for the serialized data, as
* defined by BlobSerializerBase - confirm against that interface.
*/
void Serialize(
const void* pointer,
TypeMeta typeMeta,
const string& name,
BlobSerializerBase::SerializationAcceptor acceptor) override;
};
/**
* Blob deserializer counterpart of DBReaderSerializer. Only the declaration
* lives in this header; the implementation is defined elsewhere.
*/
class CAFFE2_API DBReaderDeserializer : public BlobDeserializerBase {
public:
/**
* Deserializes `proto` into `blob`.
*
* NOTE(review): presumably rebuilds a DBReader from a serialized
* DBReaderProto carried in the BlobProto - confirm in the implementation.
*/
void Deserialize(const BlobProto& proto, Blob* blob) override;
};
} // namespace db
} // namespace caffe2
#endif // CAFFE2_CORE_DB_H_