forked from ra1nb0rn/cpe_search
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path database_wrapper_functions.py
60 lines (53 loc) · 2.23 KB
/
database_wrapper_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import sqlite3
try: # only use mariadb module if installed
import mariadb
except ImportError:
pass
# Upper bound MariaDB Connector/Python accepts for ConnectionPool(pool_size=...).
_MAX_MARIADB_POOL_SIZE = 64
# Aim for one pooled connection per CPU core (https://dba.stackexchange.com/a/305726).
# os.cpu_count() may return None, and the connector rejects pool sizes outside
# 1..64, so clamp the value into that range instead of failing at pool creation.
CONNECTION_POOL_SIZE = max(1, min(os.cpu_count() or 1, _MAX_MARIADB_POOL_SIZE))
# Registry of created MariaDB connection pools, keyed by 'pool_' + database name.
CONNECTION_POOLS = {}
def _mariadb_connect_params(config_database_keys, database_name):
    '''Build the keyword arguments shared by mariadb.connect() and mariadb.ConnectionPool().

    :param config_database_keys: database config mapping; its 'USER', 'PASSWORD',
        'HOST' and 'PORT' entries are read
    :param database_name: name of the MariaDB database to connect to
    :return: dict of connection keyword arguments
    '''
    return {
        'user': config_database_keys['USER'],
        'password': config_database_keys['PASSWORD'],
        'host': config_database_keys['HOST'],
        'port': config_database_keys['PORT'],
        'database': database_name,
    }


def get_database_connection(config_database_keys, database_name, use_pool=True):
    '''Return a database connection object, initialized with the given config

    :param config_database_keys: database config mapping; 'TYPE' selects the
        backend ('sqlite' or 'mariadb'); for MariaDB the 'USER', 'PASSWORD',
        'HOST' and 'PORT' entries are read as well
    :param database_name: path passed to sqlite3.connect(), or the MariaDB
        database name
    :param use_pool: for MariaDB, create a connection pool for this database
        if none exists yet; an already registered pool is used regardless
    :return: a sqlite3 or mariadb connection object
    :raises ValueError: if 'TYPE' is neither 'sqlite' nor 'mariadb'
    :raises ImportError: if 'TYPE' is 'mariadb' but the mariadb module is missing
    '''
    database_type = config_database_keys['TYPE']
    if database_type == 'sqlite':
        return sqlite3.connect(database_name)
    if database_type != 'mariadb':
        # ValueError subclasses Exception, so existing `except Exception` handlers still apply
        raise ValueError('Invalid database type %s given' % (database_type,))

    # The module-level mariadb import is optional; importing here reports a
    # missing module as an explicit ImportError instead of a NameError.
    import mariadb

    pool_name = 'pool_' + database_name
    pool = CONNECTION_POOLS.get(pool_name)
    if pool is None:
        if not use_pool:
            return mariadb.connect(**_mariadb_connect_params(config_database_keys, database_name))
        pool = mariadb.ConnectionPool(
            pool_name=pool_name,
            pool_size=CONNECTION_POOL_SIZE,
            **_mariadb_connect_params(config_database_keys, database_name)
        )
        CONNECTION_POOLS[pool_name] = pool

    try:
        db_conn = pool.get_connection()
    except mariadb.Error:
        # no pooled connection available (PoolError derives from mariadb.Error)
        db_conn = None
    if db_conn is None:
        # NOTE(review): some connector versions may return None instead of
        # raising when the pool is exhausted — fall back to a direct connection.
        db_conn = mariadb.connect(**_mariadb_connect_params(config_database_keys, database_name))
    return db_conn
def get_connection_pools():
    '''Expose the module-level registry of MariaDB connection pools.

    The live registry dict is returned (not a copy); its keys are the pool
    names built as 'pool_' + database name.
    '''
    pools = CONNECTION_POOLS
    return pools