Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow using the Datasource as a TEMP TABLE. #224

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3620,7 +3620,7 @@ def bootstrap_by_row(
)
random_choice_table = sql.Sql(columns, sql.Join(table, replicates))
random_choice_table_alias = with_data.add(
sql.Datasource(random_choice_table, 'BootstrapRandomRows'))
sql.Datasource(random_choice_table, 'BootstrapRandomRows', True))

using = (
sql.Columns(partition)
Expand Down
75 changes: 69 additions & 6 deletions sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,14 +496,23 @@ def __str__(self):
class Datasource(SqlComponent):
"""Represents a SQL datasource, could be a table name or a SQL query."""

def __init__(self, table, alias=None):
def __init__(self, table, alias=None, as_temp_table=False):
"""Initializes a Datasource.

Args:
table: A string representing a table name or a SQL query.
alias: The alias of the table.
as_temp_table: Whether to create a TEMP TABLE for `table`. If not, the
`table` will be treated as a subquery in the WITH clause.
"""
super(Datasource, self).__init__()
self.table = table
self.alias = alias
if isinstance(table, Datasource):
self.table = table.table
self.alias = alias or table.alias
self.alias = escape_alias(self.alias)
self.as_temp_table = as_temp_table
self.is_table = (
not str(self.table).strip().upper().startswith('SELECT')
and 'WITH ' not in str(self.table).upper()
Expand All @@ -527,6 +536,8 @@ def join(self, other, on=None, using=None, join='', alias=None):
return Join(self, other, on, using, join, alias)

def __str__(self):
if self.as_temp_table and not self.alias:
raise ValueError('Datasource as a TEMP TABLE must have an alias!')
table = self.table if self.is_table else '(%s)' % self.table
return '%s AS %s' % (table, self.alias) if self.alias else str(table)

Expand Down Expand Up @@ -571,11 +582,15 @@ class Datasources(SqlComponents):
def __init__(self, datasources=None):
super(Datasources, self).__init__()
self.children = collections.OrderedDict()
self.temp_tables = set()
self.add(datasources)

@property
def datasources(self):
  """Yields a Datasource for every (alias, table) pair in `self.children`.

  Each Datasource is built with `as_temp_table` set to whether its alias is in
  `self.temp_tables`. The superseded two-argument form, which dropped that
  flag, is removed so the flag actually reaches the Datasource.

  Returns:
    A generator of Datasource objects, in the iteration order of
    `self.children`.
  """
  return (
      Datasource(v, k, k in self.temp_tables)
      for k, v in self.children.items()
  )

def merge(self, new_child: Union[Datasource, 'Datasources', 'Sql']):
"""Merges a datasource if possible.
Expand Down Expand Up @@ -614,23 +629,33 @@ def merge(self, new_child: Union[Datasource, 'Datasources', 'Sql']):
raise ValueError(
'%s is a %s, not a Datasource! You can try .add() instead.' %
(new_child, type(new_child)))
alias, table = new_child.alias, new_child.table
alias, table, as_temp_table = (
new_child.alias,
new_child.table,
new_child.as_temp_table,
)
# If there is a compatible data, most likely it has the same alias.
if alias in self.children:
target = self.children[alias]
if isinstance(target, Sql):
merged = target.merge(table)
if merged:
if as_temp_table:
self.add_temp_table(alias)
return alias
for a, t in self.children.items():
if a == alias or not isinstance(t, Sql):
continue
merged = t.merge(table)
if merged:
if as_temp_table:
self.add_temp_table(a)
return a
while new_child.alias in self.children:
new_child.alias = add_suffix(new_child.alias)
self.children[new_child.alias] = table
if as_temp_table:
self.add_temp_table(new_child.alias)
return new_child.alias

def add(self, children: Union[Datasource, Iterable[Datasource]]):
Expand Down Expand Up @@ -663,19 +688,46 @@ def add(self, children: Union[Datasource, Iterable[Datasource]]):
return
if not isinstance(children, Datasource):
raise ValueError('Not a Datasource!')
alias, table = children.alias, children.table
alias, table, as_temp_table = (
children.alias,
children.table,
children.as_temp_table,
)
if alias not in self.children:
if table not in self.children.values():
self.children[alias] = table
if as_temp_table:
self.add_temp_table(alias)
return alias
children.alias = [k for k, v in self.children.items() if v == table][0]
if as_temp_table:
self.add_temp_table(children.alias)
return children.alias
else:
if table == self.children[alias]:
if as_temp_table:
self.add_temp_table(alias)
return alias
children.alias = add_suffix(alias)
return self.add(children)

def add_temp_table(
    self, table: 'Union[str, Sql, Join, Datasource]', _seen=None
):
  """Marks alias and all its data dependencies as temp tables.

  Args:
    table: What to mark. A string is added to `self.temp_tables` and, if it is
      an alias in `self.children`, the data it maps to is walked as well. A
      Join is walked through `ds1` and `ds2`, a Datasource through its
      `table`, and a Sql through its `from_data`.
    _seen: Internal. The aliases already expanded during this walk. Callers
      should leave it unset.

  Returns:
    None for strings and Joins, `self` for any other unrecognised input, and
    whatever the nested call returns for Datasource and Sql inputs.
  """
  if _seen is None:
    _seen = set()
  if isinstance(table, str):
    self.temp_tables.add(table)
    # Expand each alias at most once per walk. Without this guard an alias
    # that resolves to itself (children['T'] == 'T') or to a chain that leads
    # back to it recurses until RecursionError.
    if table in self.children and table not in _seen:
      _seen.add(table)
      self.add_temp_table(self.children[table], _seen)
    return
  if isinstance(table, Join):
    self.add_temp_table(table.ds1, _seen)
    self.add_temp_table(table.ds2, _seen)
    return
  if isinstance(table, Datasource):
    return self.add_temp_table(table.table, _seen)
  if isinstance(table, Sql):
    return self.add_temp_table(table.from_data, _seen)
  return self

def extend(self, other: 'Datasources'):
"""Merge other to self. Adjust the query if a new alias is needed."""
datasources = list(other.datasources)
Expand All @@ -691,7 +743,18 @@ def extend(self, other: 'Datasources'):
return self

def __str__(self):
  """Renders all datasources as SQL text.

  Datasources flagged `as_temp_table` are emitted first, one
  'CREATE OR REPLACE TEMP TABLE <expression>;' statement per line. The rest
  are emitted after them as a single 'WITH' section, separated by ',\n'.
  `<expression>` is whatever `get_expression('WITH')` returns for the
  datasource.

  Returns:
    The rendered text, or an empty string if there are no datasources. The
    text never starts with a newline, even when there are no temp tables.
  """
  # NOTE(review): the pre-patch one-liner skipped falsy datasources (`if d`);
  # this loop does not. Confirm that dropping the filter is intended.
  temp_tables = []
  with_tables = []
  for d in self.datasources:
    expression = d.get_expression('WITH')
    if d.as_temp_table:
      temp_tables.append(f'CREATE OR REPLACE TEMP TABLE {expression};')
    else:
      with_tables.append(expression)
  # Join only the sections that exist, so a WITH-only result does not start
  # with a stray newline.
  sections = list(temp_tables)
  if with_tables:
    sections.append('WITH\n' + ',\n'.join(with_tables))
  return '\n'.join(sections)


class Sql(SqlComponent):
Expand Down Expand Up @@ -766,7 +829,7 @@ def merge(self, other: 'Sql'):
return True

def __str__(self):
with_clause = 'WITH\n%s' % self.with_data if self.with_data else None
with_clause = str(self.with_data) if self.with_data else None
all_columns = self.all_columns or '*'
select_clause = f'SELECT\n{all_columns}'
from_clause = ('FROM %s'
Expand Down
Loading