Skip to content

Commit

Permalink
Merge pull request nowsecure#68 from nowsecure/sqliterecovery
Browse files Browse the repository at this point in the history
Allow to ensure the presence of a primary key
  • Loading branch information
bellini666 committed Jun 12, 2015
2 parents 2e9337e + c769494 commit 84ce27f
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 27 deletions.
101 changes: 78 additions & 23 deletions datagrid_gtk3/db/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from sqlalchemy import (
Column,
INTEGER,
MetaData,
Table,
create_engine,
Expand Down Expand Up @@ -67,6 +68,9 @@ class SQLiteDataSource(DataSource):
names, data types, transforms, etc.
:param bool ensure_selected_column: Whether to ensure the presence of
the __selected column.
:param bool ensure_primary_key: if we should ensure the presence of
a primary key on the table. If this is `True` and no primary
key is found, one will be created for you.
:param bool display_all: Whether or not all columns should be displayed.
:param str query: Full custom query to be used instead of the table name.
:param bool persist_columns_visibility: Weather we should persist
Expand All @@ -92,7 +96,8 @@ class SQLiteDataSource(DataSource):
}

def __init__(self, db_file, table=None, update_table=None, config=None,
ensure_selected_column=True, display_all=False, query=None,
ensure_primary_key=True, ensure_selected_column=True,
display_all=False, query=None,
persist_columns_visibility=True):
"""Process database column info."""
super(SQLiteDataSource, self).__init__()
Expand All @@ -104,6 +109,7 @@ def __init__(self, db_file, table=None, update_table=None, config=None,
if query:
logger.debug("Custom SQL: %s", query)
self._persist_columns_visibility = persist_columns_visibility
self._ensure_primary_key = ensure_primary_key
self._ensure_selected_column = ensure_selected_column
self.display_all = display_all
# FIXME: Use sqlalchemy for queries using update_table
Expand Down Expand Up @@ -441,6 +447,64 @@ def _ensure_temp_view(self, cursor):
self.table.name, self.query
))

def _ensure_primary_key_column(self, conn):
"""Ensure that we know what is the primary key.
:param conn: an open connection to the database
"""
# FIXME: What to do when using temporary views?
if self.query:
return

with closing(conn.cursor()) as cursor:
cursor.execute('PRAGMA table_info(%s)' % (self.table.name, ))
rows = cursor.fetchall()

row_names = [row[1] for row in rows]

# First check if there's any row that matches self.ID_COLUMN
if self.ID_COLUMN in row_names:
return

# Then try to find any row that has its primary key flag set to True
for row in rows:
if row[5]: # primary key
self.ID_COLUMN = row[1]
return

# If nothing worked and we need to ensure primary key or selected
# column, lets add it. Since sqlite doesn't allow us to add primary
# keys to existing tables, we are working around that by creating a
# new one with a primary key, copying everything from the previous
# table there.
if self._ensure_primary_key or self._ensure_selected_column:
t_name = self.table.name
candidate = '__id'
while candidate in row_names:
candidate = '_' + candidate

db = Database(self.db_file)
table = db.reflect(t_name)
new_cols = [
Column(c.name, c.type,
primary_key=c.primary_key, default=c.default)
for c in table.columns.values()]
new_cols.append(Column(candidate, INTEGER, primary_key=True))
tmp_name = '__tmp_' + t_name
new_table = Table(tmp_name, db.metadata, *new_cols)
new_table.create()

with closing(conn.cursor()) as cursor:
cursor.execute(
"INSERT INTO %s SELECT *, null FROM %s" % (
tmp_name, t_name))
cursor.execute("DROP TABLE %s" % (t_name, ))
cursor.execute(
"ALTER TABLE %s RENAME TO %s" % (tmp_name, t_name))
conn.commit()

self.ID_COLUMN = candidate

def get_columns(self):
"""Return a list of column information dicts.
Expand All @@ -461,25 +525,14 @@ def get_columns(self):
"""
cols = []
with closing(sqlite3.connect(self.db_file)) as conn:
self._ensure_primary_key_column(conn)

with closing(conn.cursor()) as cursor:
self._ensure_temp_view(cursor)
table_info_query = 'PRAGMA table_info(%s)' % self.table.name
cursor.execute(table_info_query)
rows = cursor.fetchall()

# FIXME: If the idcolumn doesn't match any column, use the
# first primary key we can find. This actually happen on the
# examples database.
if not any(row[1] == self.ID_COLUMN for row in rows):
for row in rows:
if row[5]: # primary key
self.ID_COLUMN = row[1]
break
else:
if self._ensure_selected_column:
# ID column is necessary row selection
raise ValueError("No id column found.")

has_selected = False
counter = 0
for i, row in enumerate(rows):
Expand Down Expand Up @@ -512,6 +565,7 @@ def get_columns(self):
'transform_options': options,
'expand': expand,
'visible': visible,
'from_config': self.config is not None,
}

if col_name == self.ID_COLUMN:
Expand Down Expand Up @@ -543,6 +597,7 @@ def get_columns(self):
'transform_options': 'boolean',
'expand': False,
'visible': True,
'from_config': False,
}
cols.insert(0, col_dict)
has_selected = True
Expand Down Expand Up @@ -737,19 +792,19 @@ def run_quick_check(self):
logger.warning('Integrity check failure: %s', self.db_filename)
return passed

def reflect(self):
def reflect(self, table_name):
"""Get table metadata through reflection.
sqlalchemy already provides a reflect method, but it will stop at the
first failure, while this method will try to get as much as possible.
"""
inspector = inspect(self.engine)
for table_name in inspector.get_table_names():
columns = []
for column_data in inspector.get_columns(table_name):
# Rename 'type' to 'type_' to create column object
column_type = column_data.pop('type', None)
column_data['type_'] = column_type
columns.append(Column(**column_data))
Table(table_name, self.metadata, *columns)
columns = []
for column_data in inspector.get_columns(table_name):
# Rename 'type' to 'type_' to create column object
column_type = column_data.pop('type', None)
column_data['type_'] = column_type
columns.append(Column(**column_data))

return Table(table_name, self.metadata, *columns)
8 changes: 5 additions & 3 deletions datagrid_gtk3/tests/test_datagrid-gtk3.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,9 @@ class TransformationsTest(unittest.TestCase):
def setUp(self): # noqa
"""Create test data."""
self.datagrid_model = DataGridModel(
data_source=SQLiteDataSource('', 'test',
ensure_selected_column=False),
data_source=SQLiteDataSource(
'', 'test',
ensure_selected_column=False, ensure_primary_key=False),
get_media_callback=mock.MagicMock(),
decode_fallback=mock.MagicMock()
)
Expand Down Expand Up @@ -684,7 +685,8 @@ def _transform(self, transform_type, value, transform_options=None):
self.datagrid_model.columns = [
{'name': transform_type,
'transform': transform_type,
'transform_options': transform_options}]
'transform_options': transform_options,
'from_config': True}]
return self.datagrid_model.get_formatted_value(value, 0)


Expand Down
7 changes: 6 additions & 1 deletion datagrid_gtk3/ui/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1717,7 +1717,12 @@ def get_formatted_value(self, value, column_index, visible=True):
transformer = get_transformer(transformer_name)
transformer_kwargs = {}

if value is not None and 'type' in col_dict:
# Only enforce value type if the config was provided. Otherwise,
# we would just be spamming a lot of obvious warnings (we got the type
# from introspecting the database and for sqlite, it has a high
# probability of not being an exact match in python).
if (col_dict['from_config'] and
value is not None and 'type' in col_dict):
# Try enforcing value type
value = self._enforce_value_type(value, col_dict['type'])

Expand Down

0 comments on commit 84ce27f

Please sign in to comment.