Skip to content

Commit

Permalink
fix: Editor page with redundant table fields of the same name in othe… (
Browse files Browse the repository at this point in the history
#1765)

Co-authored-by: 王玉东 <[email protected]>
  • Loading branch information
whyuds and 王玉东 authored Aug 5, 2024
1 parent 5bd946f commit 9fe060a
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 14 deletions.
2 changes: 1 addition & 1 deletion dbgpt/app/openapi/api_v1/editor/api_editor_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def get_editor_tables(
for table in tables:
table_node: DataNode = DataNode(title=table, key=table, type="table")
db_node.children.append(table_node)
fields = db_conn.get_fields(table)
fields = db_conn.get_fields(table, db_name)
for field in fields:
table_node.children.append(
DataNode(
Expand Down
15 changes: 8 additions & 7 deletions dbgpt/datasource/rdbms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,16 +532,17 @@ def get_show_create_table(self, table_name):
ans = cursor.fetchall()
return ans[0][1]

def get_fields(self, table_name) -> List[Tuple]:
def get_fields(self, table_name, db_name=None) -> List[Tuple]:
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
text(
"SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, "
"COLUMN_COMMENT from information_schema.COLUMNS where "
f"table_name='{table_name}'".format(table_name)
)
query = (
"SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, "
"COLUMN_COMMENT from information_schema.COLUMNS where "
f"table_name='{table_name}'"
)
if db_name is not None:
query += f" AND table_schema='{db_name}'"
cursor = session.execute(text(query))
fields = cursor.fetchall()
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]

Expand Down
5 changes: 3 additions & 2 deletions dbgpt/datasource/rdbms/conn_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,17 @@ def dialect(self) -> str:
"""Return string representation of dialect to use."""
return ""

def get_fields(self, table_name) -> List[Tuple]:
def get_fields(self, table_name, db_name=None) -> List[Tuple]:
"""Get column fields about specified table."""
session = self.client

_query_sql = f"""
SELECT name, type, default_expression, is_in_primary_key, comment
from system.columns where table='{table_name}'
""".format(
table_name
)
if db_name is not None:
_query_sql += f" AND database='{db_name}'"
with session.query_row_block_stream(_query_sql) as stream:
fields = [block for block in stream] # noqa
return fields
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/datasource/rdbms/conn_doris.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_columns(self, table_name: str) -> List[Dict]:
for field in fields
]

def get_fields(self, table_name) -> List[Tuple]:
def get_fields(self, table_name, db_name=None) -> List[Tuple]:
"""Get column fields about specified table."""
cursor = self.get_session().execute(
text(
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/datasource/rdbms/conn_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def get_users(self):
logger.warning(f"postgresql get users error: {str(e)}")
return []

def get_fields(self, table_name) -> List[Tuple]:
def get_fields(self, table_name, db_name=None) -> List[Tuple]:
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/datasource/rdbms/conn_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_show_create_table(self, table_name):
ans = cursor.fetchall()
return ans[0][0]

def get_fields(self, table_name) -> List[Tuple]:
def get_fields(self, table_name, db_name=None) -> List[Tuple]:
"""Get column fields about specified table."""
cursor = self.session.execute(text(f"PRAGMA table_info('{table_name}')"))
fields = cursor.fetchall()
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/datasource/rdbms/conn_vertica.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def get_users(self):
logger.warning(f"vertica get users error: {str(e)}")
return []

def get_fields(self, table_name) -> List[Tuple]:
def get_fields(self, table_name, db_name=None) -> List[Tuple]:
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
Expand Down

0 comments on commit 9fe060a

Please sign in to comment.