Skip to content

Commit

Permalink
Add basic transaction db
Browse files Browse the repository at this point in the history
  • Loading branch information
mharshe committed Nov 29, 2021
1 parent 7e56c38 commit 9849c87
Show file tree
Hide file tree
Showing 3 changed files with 336 additions and 4 deletions.
237 changes: 233 additions & 4 deletions rocksdb/_rocksdb.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ from . cimport table_factory
from . cimport memtablerep
from . cimport universal_compaction
from . cimport metadata
from . cimport stackable_db
from . cimport transaction_db

# Enums are the only exception for direct imports
# Their names are already unique enough
Expand Down Expand Up @@ -1784,6 +1784,86 @@ cdef class Options(ColumnFamilyOptions):



cdef class TransactionDBOptions(object):
    """Wrapper that owns a heap-allocated ``transaction_db.TransactionDBOptions``.

    Every keyword argument passed to the constructor is applied through
    ``setattr``, so it must name one of the properties defined below.
    """
    cdef transaction_db.TransactionDBOptions* opts

    def __cinit__(self):
        # Allocate the native struct; it is released again in __dealloc__.
        self.opts = new transaction_db.TransactionDBOptions()

    def __dealloc__(self):
        # Only free the struct when the allocation actually happened.
        if self.opts != NULL:
            del self.opts

    def __init__(self, **kwargs):
        # Route each keyword through the matching property setter.
        for option_name, option_value in kwargs.items():
            setattr(self, option_name, option_value)

    # --- plain pass-through fields of the native struct -------------------

    property max_num_locks:
        def __get__(self):
            return self.opts.max_num_locks

        def __set__(self, new_value):
            self.opts.max_num_locks = new_value

    property max_num_deadlocks:
        def __get__(self):
            return self.opts.max_num_deadlocks

        def __set__(self, new_value):
            self.opts.max_num_deadlocks = new_value

    property num_stripes:
        def __get__(self):
            return self.opts.num_stripes

        def __set__(self, new_value):
            self.opts.num_stripes = new_value

    property transaction_lock_timeout:
        def __get__(self):
            return self.opts.transaction_lock_timeout

        def __set__(self, new_value):
            self.opts.transaction_lock_timeout = new_value

    property default_lock_timeout:
        def __get__(self):
            return self.opts.default_lock_timeout

        def __set__(self, new_value):
            self.opts.default_lock_timeout = new_value

    # TODO property custom_mutex_factory

    property write_policy:
        # Exposed as one of the strings 'write_committed', 'write_prepared'
        # or 'write_unprepared'; translated to/from the native enum.
        def __get__(self):
            cdef transaction_db.TxnDBWritePolicy current = self.opts.write_policy
            if current == transaction_db.WRITE_COMMITTED:
                return 'write_committed'
            elif current == transaction_db.WRITE_PREPARED:
                return 'write_prepared'
            elif current == transaction_db.WRITE_UNPREPARED:
                return 'write_unprepared'
            raise Exception("Unknown write policy")

        def __set__(self, str value):
            cdef transaction_db.TxnDBWritePolicy chosen
            if value == 'write_committed':
                chosen = transaction_db.WRITE_COMMITTED
            elif value == 'write_prepared':
                chosen = transaction_db.WRITE_PREPARED
            elif value == 'write_unprepared':
                chosen = transaction_db.WRITE_UNPREPARED
            else:
                raise Exception("Unknown write policy")
            # Only reached for a recognised name.
            self.opts.write_policy = chosen

    property rollback_merge_operands:
        def __get__(self):
            return self.opts.rollback_merge_operands

        def __set__(self, new_value):
            self.opts.rollback_merge_operands = new_value

    property skip_concurrency_control:
        def __get__(self):
            return self.opts.skip_concurrency_control

        def __set__(self, new_value):
            self.opts.skip_concurrency_control = new_value

    property default_write_batch_flush_threshold:
        def __get__(self):
            return self.opts.default_write_batch_flush_threshold

        def __set__(self, new_value):
            self.opts.default_write_batch_flush_threshold = new_value

# Forward declaration
cdef class Snapshot
Expand Down Expand Up @@ -2587,9 +2667,158 @@ def list_column_families(db_name, Options opts):
return column_families

@cython.no_gc_clear
cdef class StackableDB(DB):
    # Extension type derived from DB; its __cinit__ accepts the same
    # arguments as listed in the signature below.
    def __cinit__(self, db_name, Options opts, dict column_families=None, read_only=False):
        # Replace the native handle with a stackable_db.StackableDB built
        # from the current value of self.db.
        # NOTE(review): presumably self.db was already opened by the base
        # class __cinit__ at this point — confirm against DB.
        self.db = new stackable_db.StackableDB(self.db)
cdef class TransactionDB(object):
    """Handle to a database opened through ``TransactionDB_Open_ColumnFamilies``.

    The instance keeps the ``Options`` object it was opened with, one
    ``_ColumnFamilyHandle`` wrapper per opened column family, and the
    per-family ``ColumnFamilyOptions`` objects (``None`` for the implicitly
    added default family).
    """
    cdef Options opts
    cdef transaction_db.TransactionDB* db
    cdef list cf_handles
    cdef list cf_options

    def __cinit__(self, db_name, Options opts, TransactionDBOptions tdb_opts, dict column_families=None):
        """Open the database at ``db_name``.

        :param db_name: database location, converted with ``path_to_string``.
        :param opts: database options; must not already be in use.
        :param tdb_opts: transaction-db specific options.
        :param column_families: optional ``{bytes name: ColumnFamilyOptions}``
            mapping of column families to open.
        :raises TypeError: if a column family name is not ``bytes`` or its
            options are not a ``ColumnFamilyOptions``.
        :raises Exception: if ``opts`` or any ``ColumnFamilyOptions`` is
            already marked as in use.
        """
        cdef Status st
        cdef string db_path
        cdef vector[db.ColumnFamilyDescriptor] column_family_descriptors
        cdef vector[db.ColumnFamilyHandle*] column_family_handles
        cdef bytes default_cf_name = db.kDefaultColumnFamilyName
        self.db = NULL
        self.opts = None
        self.cf_handles = []
        self.cf_options = []

        if opts.in_use:
            raise Exception("Options object is already used by another DB")

        db_path = path_to_string(db_name)
        if not column_families or default_cf_name not in column_families:
            # Always add the default column family
            column_family_descriptors.push_back(
                db.ColumnFamilyDescriptor(
                    db.kDefaultColumnFamilyName,
                    options.ColumnFamilyOptions(deref(opts.opts))
                )
            )
            self.cf_options.append(None)  # Since they are the same as db

        # Register every user-supplied column family exactly once. Running
        # this loop a second time would always fail: the first pass marks
        # each options object as in use, which the second pass rejects.
        if column_families:
            for cf_name, cf_options in column_families.items():
                if not isinstance(cf_name, bytes):
                    raise TypeError(
                        f"column family name {cf_name!r} is not of type {bytes}!"
                    )
                if not isinstance(cf_options, ColumnFamilyOptions):
                    raise TypeError(
                        f"column family options {cf_options!r} is not of type "
                        f"{ColumnFamilyOptions}!"
                    )
                if (<ColumnFamilyOptions>cf_options).in_use:
                    raise Exception(
                        f"ColumnFamilyOptions object for {cf_name} is already "
                        "used by another Column Family"
                    )
                (<ColumnFamilyOptions>cf_options).in_use = True
                column_family_descriptors.push_back(
                    db.ColumnFamilyDescriptor(
                        cf_name,
                        deref((<ColumnFamilyOptions>cf_options).copts)
                    )
                )
                self.cf_options.append(cf_options)

        with nogil:
            st = transaction_db.TransactionDB_Open_ColumnFamilies(
                deref(opts.opts),
                deref(tdb_opts.opts),
                db_path,
                column_family_descriptors,
                &column_family_handles,
                &self.db)
        check_status(st)

        for handle in column_family_handles:
            wrapper = _ColumnFamilyHandle.from_handle_ptr(handle)
            self.cf_handles.append(wrapper)

        # Inject the loggers into the python callbacks
        cdef shared_ptr[logger.Logger] info_log = self.db.GetOptions(
            self.db.DefaultColumnFamily()).info_log
        if opts.py_comparator is not None:
            opts.py_comparator.set_info_log(info_log)

        if opts.py_table_factory is not None:
            opts.py_table_factory.set_info_log(info_log)

        if opts.prefix_extractor is not None:
            opts.py_prefix_extractor.set_info_log(info_log)

        # self.cf_options was filled in the same order as the descriptors,
        # so idx is used to index column_family_handles as well.
        cdef ColumnFamilyOptions copts
        for idx, copts in enumerate(self.cf_options):
            if not copts:
                # Placeholder for the default family; handled via opts above.
                continue

            info_log = self.db.GetOptions(column_family_handles[idx]).info_log

            if copts.py_comparator is not None:
                copts.py_comparator.set_info_log(info_log)

            if copts.py_table_factory is not None:
                copts.py_table_factory.set_info_log(info_log)

            if copts.prefix_extractor is not None:
                copts.py_prefix_extractor.set_info_log(info_log)

        self.opts = opts
        self.opts.in_use = True

    def close(self, safe=True):
        """Close the database if it is open; do nothing otherwise.

        :param safe: forwarded to ``db.CancelAllBackgroundWork``
            (presumably its "wait" flag — confirm against the RocksDB API).
        """
        cdef ColumnFamilyOptions copts
        cdef cpp_bool c_safe = safe
        cdef Status st
        if self.db != NULL:
            # We need to stop background compactions
            with nogil:
                db.CancelAllBackgroundWork(self.db, c_safe)
            # We have to make sure we delete the handles so rocksdb doesn't
            # assert when we delete the db
            del self.cf_handles[:]
            for copts in self.cf_options:
                if copts:
                    copts.in_use = False
            del self.cf_options[:]
            # The status returned by Close() is not inspected here.
            with nogil:
                st = self.db.Close()
            self.db = NULL
            if self.opts is not None:
                self.opts.in_use = False

    def __dealloc__(self):
        self.close()

    property options:
        # The Options object this database was opened with (None until the
        # open has succeeded).
        def __get__(self):
            return self.opts


@cython.no_gc_clear
@cython.internal
Expand Down
29 changes: 29 additions & 0 deletions rocksdb/tests/test_transaction_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
import sys
import shutil
import gc
import unittest
import rocksdb
from itertools import takewhile
import struct
import tempfile
from rocksdb.merge_operators import UintAddOperator, StringAppendOperator

from .test_db import TestHelper

class TestTransactionDB(TestHelper):
    """Basic checks for ``rocksdb.TransactionDB``."""

    def setUp(self):
        """Open a fresh TransactionDB for each test."""
        TestHelper.setUp(self)
        opts = rocksdb.Options(create_if_missing=True)
        tdb_opts = rocksdb.TransactionDBOptions()
        # Open under self.db_loc (the same base directory the test below
        # uses) instead of a fixed "/tmp/test" path shared by every run.
        # NOTE(review): assumes TestHelper.setUp() creates self.db_loc — confirm.
        self.db = rocksdb.TransactionDB(
            os.path.join(self.db_loc, "test"), opts, tdb_opts)

    def test_options_used_twice(self):
        """Re-using the Options of an open DB for a second DB must raise."""
        # assertRaisesRegexp is the Python 2 spelling of assertRaisesRegex.
        if sys.version_info[0] == 3:
            assertRaisesRegex = self.assertRaisesRegex
        else:
            assertRaisesRegex = self.assertRaisesRegexp
        expected = "Options object is already used by another DB"
        tdb_opts = rocksdb.TransactionDBOptions()
        with assertRaisesRegex(Exception, expected):
            rocksdb.TransactionDB(
                os.path.join(self.db_loc, "test2"), self.db.options, tdb_opts)
74 changes: 74 additions & 0 deletions rocksdb/transaction_db.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from . cimport options
from libc.stdint cimport uint64_t, uint32_t, int64_t
from .status cimport Status
from libcpp cimport bool as cpp_bool
from libcpp.string cimport string
from libcpp.vector cimport vector
from libcpp.map cimport map
from libcpp.unordered_map cimport unordered_map
from libcpp.memory cimport shared_ptr
from .types cimport SequenceNumber
from .slice_ cimport Slice
from .snapshot cimport Snapshot
from .iterator cimport Iterator
from .env cimport Env
from .metadata cimport ColumnFamilyMetaData
from .metadata cimport LiveFileMetaData
from .metadata cimport ExportImportFilesMetaData
from .table_properties cimport TableProperties
from .db cimport DB, WriteBatch, ColumnFamilyDescriptor, ColumnFamilyHandle
from .stackable_db cimport StackableDB

# Cython declarations for the C++ names in rocksdb/utilities/transaction_db.h
# (namespace "rocksdb") that the wrapper code refers to.
cdef extern from "rocksdb/utilities/transaction_db.h" namespace "rocksdb":
    # Write-policy enum; cpdef also makes the members visible from Python.
    cpdef enum TxnDBWritePolicy:
        WRITE_COMMITTED
        WRITE_PREPARED
        WRITE_UNPREPARED

    # Options struct passed to the TransactionDB open functions below.
    cdef cppclass TransactionDBOptions:
        int64_t max_num_locks
        uint32_t max_num_deadlocks
        size_t num_stripes
        int64_t transaction_lock_timeout
        int64_t default_lock_timeout
        # TODO shared_ptr[TransactionDBMutexFactory] custom_mutex_factory
        TxnDBWritePolicy write_policy
        cpp_bool rollback_merge_operands
        cpp_bool skip_concurrency_control
        int64_t default_write_batch_flush_threshold

    # Per-transaction options struct (declared only; no function in this
    # block takes it yet).
    cdef cppclass TransactionOptions:
        cpp_bool set_snapshot
        cpp_bool deadlock_detect
        cpp_bool use_only_the_last_commit_time_batch_for_recovery
        int64_t lock_timeout
        int64_t expiration
        int64_t deadlock_detect_depth
        size_t max_write_batch_size
        cpp_bool skip_concurrency_control
        cpp_bool skip_prepare
        int64_t write_batch_flush_threshold

    # Flags accepted by TransactionDB.Write below.
    cdef cppclass TransactionDBWriteOptimizations:
        cpp_bool skip_concurrency_control
        cpp_bool skip_duplicate_key_check

    # TransactionDB is declared as a subclass of StackableDB.
    cdef cppclass TransactionDB(StackableDB):
        Status Write(const options.WriteOptions&,
                     const TransactionDBWriteOptimizations&,
                     WriteBatch*) nogil except+

    # Both aliases below bind to the same C++ name,
    # rocksdb::TransactionDB::Open; the distinct Cython names select the
    # overload by argument list.

    # Overload without column families.
    cdef Status TransactionDB_Open "rocksdb::TransactionDB::Open"(
        const options.Options&,
        const TransactionDBOptions&,
        const string&,
        TransactionDB**) nogil except+

    # Overload taking column family descriptors and an output vector of
    # handles.
    cdef Status TransactionDB_Open_ColumnFamilies "rocksdb::TransactionDB::Open"(
        const options.DBOptions&,
        const TransactionDBOptions&,
        const string&,
        const vector[ColumnFamilyDescriptor]&,
        vector[ColumnFamilyHandle*]*,
        TransactionDB**) nogil except+

0 comments on commit 9849c87

Please sign in to comment.