Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid aliasing known tables used in CTEs #369

Merged
merged 1 commit into from Oct 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 16 additions & 6 deletions sqlalchemy_bigquery/base.py
Expand Up @@ -51,6 +51,7 @@
from sqlalchemy.engine.base import Engine
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.selectable import CTE
from sqlalchemy.sql import elements, selectable
import re

Expand Down Expand Up @@ -254,6 +255,20 @@ def visit_table_valued_alias(self, element, **kw):
ret = f"{aliases}, {ret}"
return ret

def _known_tables(self):
known_tables = set()

for from_ in self.compile_state.froms:
if isinstance(from_, Table):
known_tables.add(from_.name)
elif isinstance(from_, CTE):
for column in from_.original.selected_columns:
table = getattr(column, "table", None)
if table is not None:
known_tables.add(table.name)

return known_tables

def visit_column(
self,
column,
Expand Down Expand Up @@ -290,12 +305,7 @@ def visit_column(
if isinstance(tablename, elements._truncated_label):
tablename = self._truncated_identifier("alias", tablename)
elif TABLE_VALUED_ALIAS_ALIASES in kwargs:
known_tables = set(
from_.name
for from_ in self.compile_state.froms
if isinstance(from_, Table)
)
if tablename not in known_tables:
if tablename not in self._known_tables():
aliases = kwargs[TABLE_VALUED_ALIAS_ALIASES]
if tablename not in aliases:
aliases[tablename] = self.anon_map[
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/test_compiler.py
Expand Up @@ -76,3 +76,39 @@ def test_no_alias_for_known_tables(faux_conn, metadata):
)
found_sql = q.compile(faux_conn).string
assert found_sql == expected_sql


@sqlalchemy_1_4_or_higher
def test_no_alias_for_known_tables_cte(faux_conn, metadata):
# See: https://github.com/googleapis/python-bigquery-sqlalchemy/issues/368
table = setup_table(
faux_conn,
"table1",
metadata,
sqlalchemy.Column("foo", sqlalchemy.Integer),
sqlalchemy.Column("bars", sqlalchemy.ARRAY(sqlalchemy.Integer)),
)
F = sqlalchemy.func

# Set up initiali query
q = sqlalchemy.select(table.c.foo, F.unnest(table.c.bars).column_valued("bar"))

expected_initial_sql = (
"SELECT `table1`.`foo`, `bar` \n"
"FROM `table1`, unnest(`table1`.`bars`) AS `bar`"
)
found_initial_sql = q.compile(faux_conn).string
assert found_initial_sql == expected_initial_sql

q = q.cte("cte")
q = sqlalchemy.select(*q.columns)

expected_cte_sql = (
"WITH `cte` AS \n"
"(SELECT `table1`.`foo` AS `foo`, `bar` \n"
"FROM `table1`, unnest(`table1`.`bars`) AS `bar`)\n"
" SELECT `cte`.`foo`, `cte`.`bar` \n"
"FROM `cte`"
)
found_cte_sql = q.compile(faux_conn).string
assert found_cte_sql == expected_cte_sql