Skip to content

Commit

Permalink
Allow to use the Datasource as a TEMP TABLE.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 710112682
  • Loading branch information
tcya authored and meterstick-copybara committed Dec 28, 2024
1 parent 3a70436 commit e306ef8
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 7 deletions.
2 changes: 1 addition & 1 deletion operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3620,7 +3620,7 @@ def bootstrap_by_row(
)
random_choice_table = sql.Sql(columns, sql.Join(table, replicates))
random_choice_table_alias = with_data.add(
sql.Datasource(random_choice_table, 'BootstrapRandomRows'))
sql.Datasource(random_choice_table, 'BootstrapRandomRows', True))

using = (
sql.Columns(partition)
Expand Down
75 changes: 69 additions & 6 deletions sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,14 +496,23 @@ def __str__(self):
class Datasource(SqlComponent):
"""Represents a SQL datasource, could be a table name or a SQL query."""

def __init__(self, table, alias=None):
def __init__(self, table, alias=None, as_temp_table=False):
"""Initializes a Datasource.
Args:
table: A string representing a table name or a SQL query.
alias: The alias of the table.
as_temp_table: Whether to create a TEMP TABLE for `table`. If not, the
`table` will be treated as a subquery in the WITH clause.
"""
super(Datasource, self).__init__()
self.table = table
self.alias = alias
if isinstance(table, Datasource):
self.table = table.table
self.alias = alias or table.alias
self.alias = escape_alias(self.alias)
self.as_temp_table = as_temp_table
self.is_table = (
not str(self.table).strip().upper().startswith('SELECT')
and 'WITH ' not in str(self.table).upper()
Expand All @@ -527,6 +536,8 @@ def join(self, other, on=None, using=None, join='', alias=None):
return Join(self, other, on, using, join, alias)

def __str__(self):
if self.as_temp_table and not self.alias:
raise ValueError('Datasource as a TEMP TABLE must have an alias!')
table = self.table if self.is_table else '(%s)' % self.table
return '%s AS %s' % (table, self.alias) if self.alias else str(table)

Expand Down Expand Up @@ -571,11 +582,15 @@ class Datasources(SqlComponents):
def __init__(self, datasources=None):
super(Datasources, self).__init__()
self.children = collections.OrderedDict()
self.temp_tables = set()
self.add(datasources)

@property
def datasources(self):
return (Datasource(v, k) for k, v in self.children.items())
return (
Datasource(v, k, k in self.temp_tables)
for k, v in self.children.items()
)

def merge(self, new_child: Union[Datasource, 'Datasources', 'Sql']):
"""Merges a datasource if possible.
Expand Down Expand Up @@ -614,23 +629,33 @@ def merge(self, new_child: Union[Datasource, 'Datasources', 'Sql']):
raise ValueError(
'%s is a %s, not a Datasource! You can try .add() instead.' %
(new_child, type(new_child)))
alias, table = new_child.alias, new_child.table
alias, table, as_temp_table = (
new_child.alias,
new_child.table,
new_child.as_temp_table,
)
# If there is a compatible data, most likely it has the same alias.
if alias in self.children:
target = self.children[alias]
if isinstance(target, Sql):
merged = target.merge(table)
if merged:
if as_temp_table:
self.add_temp_table(alias)
return alias
for a, t in self.children.items():
if a == alias or not isinstance(t, Sql):
continue
merged = t.merge(table)
if merged:
if as_temp_table:
self.add_temp_table(a)
return a
while new_child.alias in self.children:
new_child.alias = add_suffix(new_child.alias)
self.children[new_child.alias] = table
if as_temp_table:
self.add_temp_table(new_child.alias)
return new_child.alias

def add(self, children: Union[Datasource, Iterable[Datasource]]):
Expand Down Expand Up @@ -663,19 +688,46 @@ def add(self, children: Union[Datasource, Iterable[Datasource]]):
return
if not isinstance(children, Datasource):
raise ValueError('Not a Datasource!')
alias, table = children.alias, children.table
alias, table, as_temp_table = (
children.alias,
children.table,
children.as_temp_table,
)
if alias not in self.children:
if table not in self.children.values():
self.children[alias] = table
if as_temp_table:
self.add_temp_table(alias)
return alias
children.alias = [k for k, v in self.children.items() if v == table][0]
if as_temp_table:
self.add_temp_table(children.alias)
return children.alias
else:
if table == self.children[alias]:
if as_temp_table:
self.add_temp_table(alias)
return alias
children.alias = add_suffix(alias)
return self.add(children)

def add_temp_table(self, table: Union[str, 'Sql', Join, Datasource]):
"""Marks alias and all its data dependencies as temp tables."""
if isinstance(table, str):
self.temp_tables.add(table)
if table in self.children:
self.add_temp_table(self.children[table])
return
if isinstance(table, Join):
self.add_temp_table(table.ds1)
self.add_temp_table(table.ds2)
return
if isinstance(table, Datasource):
return self.add_temp_table(table.table)
if isinstance(table, Sql):
return self.add_temp_table(table.from_data)
return self

def extend(self, other: 'Datasources'):
"""Merge other to self. Adjust the query if a new alias is needed."""
datasources = list(other.datasources)
Expand All @@ -691,7 +743,18 @@ def extend(self, other: 'Datasources'):
return self

def __str__(self):
return ',\n'.join((d.get_expression('WITH') for d in self.datasources if d))
temp_tables = []
with_tables = []
for d in self.datasources:
expression = d.get_expression('WITH')
if d.as_temp_table:
temp_tables.append(f'CREATE OR REPLACE TEMP TABLE {expression};')
else:
with_tables.append(expression)
res = '\n'.join(temp_tables)
if with_tables:
res += '\nWITH\n' + ',\n'.join(with_tables)
return res


class Sql(SqlComponent):
Expand Down Expand Up @@ -766,7 +829,7 @@ def merge(self, other: 'Sql'):
return True

def __str__(self):
with_clause = 'WITH\n%s' % self.with_data if self.with_data else None
with_clause = str(self.with_data) if self.with_data else None
all_columns = self.all_columns or '*'
select_clause = f'SELECT\n{all_columns}'
from_clause = ('FROM %s'
Expand Down

0 comments on commit e306ef8

Please sign in to comment.