diff --git a/datagrid_gtk3/db/sqlite.py b/datagrid_gtk3/db/sqlite.py index 71a977b..803f849 100644 --- a/datagrid_gtk3/db/sqlite.py +++ b/datagrid_gtk3/db/sqlite.py @@ -8,6 +8,7 @@ from sqlalchemy import ( Column, + INTEGER, MetaData, Table, create_engine, @@ -67,6 +68,9 @@ class SQLiteDataSource(DataSource): names, data types, transforms, etc. :param bool ensure_selected_column: Whether to ensure the presence of the __selected column. + :param bool ensure_primary_key: if we should ensure the presence of + a primary key on the table. If this is `True` and no primary + key is found, one will be created for you. :param bool display_all: Whether or not all columns should be displayed. :param str query: Full custom query to be used instead of the table name. :param bool persist_columns_visibility: Weather we should persist @@ -92,7 +96,8 @@ class SQLiteDataSource(DataSource): } def __init__(self, db_file, table=None, update_table=None, config=None, - ensure_selected_column=True, display_all=False, query=None, + ensure_primary_key=True, ensure_selected_column=True, + display_all=False, query=None, persist_columns_visibility=True): """Process database column info.""" super(SQLiteDataSource, self).__init__() @@ -104,6 +109,7 @@ def __init__(self, db_file, table=None, update_table=None, config=None, if query: logger.debug("Custom SQL: %s", query) self._persist_columns_visibility = persist_columns_visibility + self._ensure_primary_key = ensure_primary_key self._ensure_selected_column = ensure_selected_column self.display_all = display_all # FIXME: Use sqlalchemy for queries using update_table @@ -441,6 +447,64 @@ def _ensure_temp_view(self, cursor): self.table.name, self.query )) + def _ensure_primary_key_column(self, conn): + """Ensure that we know what is the primary key. + + :param conn: an open connection to the database + """ + # FIXME: What to do when using temporary views? + if self.query: + return + + with closing(conn.cursor()) as cursor: + cursor.execute('PRAGMA table_info(%s)' % (self.table.name, )) + rows = cursor.fetchall() + + row_names = [row[1] for row in rows] + + # First check if there's any row that matches self.ID_COLUMN + if self.ID_COLUMN in row_names: + return + + # Then try to find any row that has its primary key flag set to True + for row in rows: + if row[5]: # primary key + self.ID_COLUMN = row[1] + return + + # If nothing worked and we need to ensure primary key or selected + # column, lets add it. Since sqlite doesn't allow us to add primary + # keys to existing tables, we are working around that by creating a + # new one with a primary key, copying everything from the previous + # table there. + if self._ensure_primary_key or self._ensure_selected_column: + t_name = self.table.name + candidate = '__id' + while candidate in row_names: + candidate = '_' + candidate + + db = Database(self.db_file) + table = db.reflect(t_name) + new_cols = [ + Column(c.name, c.type, + primary_key=c.primary_key, default=c.default) + for c in table.columns.values()] + new_cols.append(Column(candidate, INTEGER, primary_key=True)) + tmp_name = '__tmp_' + t_name + new_table = Table(tmp_name, db.metadata, *new_cols) + new_table.create() + + with closing(conn.cursor()) as cursor: + cursor.execute( + "INSERT INTO %s SELECT *, null FROM %s" % ( + tmp_name, t_name)) + cursor.execute("DROP TABLE %s" % (t_name, )) + cursor.execute( + "ALTER TABLE %s RENAME TO %s" % (tmp_name, t_name)) + conn.commit() + + self.ID_COLUMN = candidate + def get_columns(self): """Return a list of column information dicts. @@ -461,25 +525,14 @@ def get_columns(self): """ cols = [] with closing(sqlite3.connect(self.db_file)) as conn: + self._ensure_primary_key_column(conn) + with closing(conn.cursor()) as cursor: self._ensure_temp_view(cursor) table_info_query = 'PRAGMA table_info(%s)' % self.table.name cursor.execute(table_info_query) rows = cursor.fetchall() - # FIXME: If the idcolumn doesn't match any column, use the - # first primary key we can find. This actually happen on the - # examples database. - if not any(row[1] == self.ID_COLUMN for row in rows): - for row in rows: - if row[5]: # primary key - self.ID_COLUMN = row[1] - break - else: - if self._ensure_selected_column: - # ID column is necessary row selection - raise ValueError("No id column found.") - has_selected = False counter = 0 for i, row in enumerate(rows): @@ -512,6 +565,7 @@ def get_columns(self): 'transform_options': options, 'expand': expand, 'visible': visible, + 'from_config': self.config is not None, } if col_name == self.ID_COLUMN: @@ -543,6 +597,7 @@ def get_columns(self): 'transform_options': 'boolean', 'expand': False, 'visible': True, + 'from_config': False, } cols.insert(0, col_dict) has_selected = True @@ -737,7 +792,7 @@ def run_quick_check(self): logger.warning('Integrity check failure: %s', self.db_filename) return passed - def reflect(self): + def reflect(self, table_name): """Get table metadata through reflection. sqlalchemy already provides a reflect method, but it will stop at the @@ -745,11 +800,11 @@ def reflect(self): """ inspector = inspect(self.engine) - for table_name in inspector.get_table_names(): - columns = [] - for column_data in inspector.get_columns(table_name): - # Rename 'type' to 'type_' to create column object - column_type = column_data.pop('type', None) - column_data['type_'] = column_type - columns.append(Column(**column_data)) - Table(table_name, self.metadata, *columns) + columns = [] + for column_data in inspector.get_columns(table_name): + # Rename 'type' to 'type_' to create column object + column_type = column_data.pop('type', None) + column_data['type_'] = column_type + columns.append(Column(**column_data)) + + return Table(table_name, self.metadata, *columns) diff --git a/datagrid_gtk3/tests/test_datagrid-gtk3.py b/datagrid_gtk3/tests/test_datagrid-gtk3.py index dcab904..86b83c1 100644 --- a/datagrid_gtk3/tests/test_datagrid-gtk3.py +++ b/datagrid_gtk3/tests/test_datagrid-gtk3.py @@ -354,8 +354,9 @@ class TransformationsTest(unittest.TestCase): def setUp(self): # noqa """Create test data.""" self.datagrid_model = DataGridModel( - data_source=SQLiteDataSource('', 'test', - ensure_selected_column=False), + data_source=SQLiteDataSource( + '', 'test', + ensure_selected_column=False, ensure_primary_key=False), get_media_callback=mock.MagicMock(), decode_fallback=mock.MagicMock() ) @@ -684,7 +685,8 @@ def _transform(self, transform_type, value, transform_options=None): self.datagrid_model.columns = [ {'name': transform_type, 'transform': transform_type, - 'transform_options': transform_options}] + 'transform_options': transform_options, + 'from_config': True}] return self.datagrid_model.get_formatted_value(value, 0) diff --git a/datagrid_gtk3/ui/grid.py b/datagrid_gtk3/ui/grid.py index 3e544f4..493a30e 100644 --- a/datagrid_gtk3/ui/grid.py +++ b/datagrid_gtk3/ui/grid.py @@ -1717,7 +1717,12 @@ def get_formatted_value(self, value, column_index, visible=True): transformer = get_transformer(transformer_name) transformer_kwargs = {} - if value is not None and 'type' in col_dict: + # Only enforce value type if the config was provided. Otherwise, + # we would just be spamming a lot of obvious warnings (we got the type + # from introspecting the database and for sqlite, it has a high + # probability of not being an exact match in python). + if (col_dict['from_config'] and + value is not None and 'type' in col_dict): # Try enforcing value type value = self._enforce_value_type(value, col_dict['type'])