Skip to content

Commit

Permalink
fix: avoid aliasing known tables used in CTEs (#369)
Browse files Browse the repository at this point in the history
Toward #368.
  • Loading branch information
tseaver committed Oct 29, 2021
1 parent f6ad9c9 commit 4b05d21
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 6 deletions.
22 changes: 16 additions & 6 deletions sqlalchemy_bigquery/base.py
Expand Up @@ -51,6 +51,7 @@
from sqlalchemy.engine.base import Engine
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.selectable import CTE
from sqlalchemy.sql import elements, selectable
import re

Expand Down Expand Up @@ -254,6 +255,20 @@ def visit_table_valued_alias(self, element, **kw):
ret = f"{aliases}, {ret}"
return ret

def _known_tables(self):
    """Return the names of tables already available to the current statement.

    ``visit_column`` consults this set to decide whether a table name
    referenced by a column still needs an alias when rendering a
    table-valued alias; names in the set are not aliased.

    Returns:
        set: the ``name`` of every ``Table`` and ``CTE`` found in
        ``self.compile_state.froms``, plus the names of the tables
        referenced by each CTE's selected columns.
    """
    known_tables = set()

    for from_ in self.compile_state.froms:
        if isinstance(from_, Table):
            known_tables.add(from_.name)
        elif isinstance(from_, CTE):
            # The CTE is itself an entry in the FROM list, so columns
            # qualified by its name must not trigger an extra alias either.
            known_tables.add(from_.name)
            # Tables selected from inside the CTE are known as well.
            for column in from_.original.selected_columns:
                # Not every selected element carries a ``table``
                # attribute, hence the defensive lookup.
                table = getattr(column, "table", None)
                if table is not None:
                    known_tables.add(table.name)

    return known_tables

def visit_column(
self,
column,
Expand Down Expand Up @@ -290,12 +305,7 @@ def visit_column(
if isinstance(tablename, elements._truncated_label):
tablename = self._truncated_identifier("alias", tablename)
elif TABLE_VALUED_ALIAS_ALIASES in kwargs:
known_tables = set(
from_.name
for from_ in self.compile_state.froms
if isinstance(from_, Table)
)
if tablename not in known_tables:
if tablename not in self._known_tables():
aliases = kwargs[TABLE_VALUED_ALIAS_ALIASES]
if tablename not in aliases:
aliases[tablename] = self.anon_map[
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/test_compiler.py
Expand Up @@ -76,3 +76,39 @@ def test_no_alias_for_known_tables(faux_conn, metadata):
)
found_sql = q.compile(faux_conn).string
assert found_sql == expected_sql


@sqlalchemy_1_4_or_higher
def test_no_alias_for_known_tables_cte(faux_conn, metadata):
    """Check that a table used inside a CTE is not given a spurious alias.

    The query is compiled on its own first, then wrapped in a CTE and
    selected from; both renderings are compared with the expected SQL.
    """
    # See: https://github.com/googleapis/python-bigquery-sqlalchemy/issues/368
    table = setup_table(
        faux_conn,
        "table1",
        metadata,
        sqlalchemy.Column("foo", sqlalchemy.Integer),
        sqlalchemy.Column("bars", sqlalchemy.ARRAY(sqlalchemy.Integer)),
    )

    # Set up the initial query: a plain column plus an unnested array column.
    initial_query = sqlalchemy.select(
        table.c.foo,
        sqlalchemy.func.unnest(table.c.bars).column_valued("bar"),
    )
    assert initial_query.compile(faux_conn).string == (
        "SELECT `table1`.`foo`, `bar` \n"
        "FROM `table1`, unnest(`table1`.`bars`) AS `bar`"
    )

    # Wrap the initial query in a CTE and select every column from it.
    cte = initial_query.cte("cte")
    cte_query = sqlalchemy.select(*cte.columns)
    assert cte_query.compile(faux_conn).string == (
        "WITH `cte` AS \n"
        "(SELECT `table1`.`foo` AS `foo`, `bar` \n"
        "FROM `table1`, unnest(`table1`.`bars`) AS `bar`)\n"
        " SELECT `cte`.`foo`, `cte`.`bar` \n"
        "FROM `cte`"
    )

0 comments on commit 4b05d21

Please sign in to comment.