diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index f2da562d..ae96d6f4 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -51,6 +51,7 @@ from sqlalchemy.engine.base import Engine from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Table +from sqlalchemy.sql.selectable import CTE from sqlalchemy.sql import elements, selectable import re @@ -254,6 +255,20 @@ def visit_table_valued_alias(self, element, **kw): ret = f"{aliases}, {ret}" return ret + def _known_tables(self): + known_tables = set() + + for from_ in self.compile_state.froms: + if isinstance(from_, Table): + known_tables.add(from_.name) + elif isinstance(from_, CTE): + for column in from_.original.selected_columns: + table = getattr(column, "table", None) + if table is not None: + known_tables.add(table.name) + + return known_tables + def visit_column( self, column, @@ -290,12 +305,7 @@ def visit_column( if isinstance(tablename, elements._truncated_label): tablename = self._truncated_identifier("alias", tablename) elif TABLE_VALUED_ALIAS_ALIASES in kwargs: - known_tables = set( - from_.name - for from_ in self.compile_state.froms - if isinstance(from_, Table) - ) - if tablename not in known_tables: + if tablename not in self._known_tables(): aliases = kwargs[TABLE_VALUED_ALIAS_ALIASES] if tablename not in aliases: aliases[tablename] = self.anon_map[ diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py index 5da4e935..889ad63d 100644 --- a/tests/unit/test_compiler.py +++ b/tests/unit/test_compiler.py @@ -76,3 +76,39 @@ def test_no_alias_for_known_tables(faux_conn, metadata): ) found_sql = q.compile(faux_conn).string assert found_sql == expected_sql + + +@sqlalchemy_1_4_or_higher +def test_no_alias_for_known_tables_cte(faux_conn, metadata): + # See: https://github.com/googleapis/python-bigquery-sqlalchemy/issues/368 + table = setup_table( + faux_conn, + "table1", + metadata, + sqlalchemy.Column("foo", sqlalchemy.Integer), + sqlalchemy.Column("bars", sqlalchemy.ARRAY(sqlalchemy.Integer)), + ) + F = sqlalchemy.func + + # Set up initiali query + q = sqlalchemy.select(table.c.foo, F.unnest(table.c.bars).column_valued("bar")) + + expected_initial_sql = ( + "SELECT `table1`.`foo`, `bar` \n" + "FROM `table1`, unnest(`table1`.`bars`) AS `bar`" + ) + found_initial_sql = q.compile(faux_conn).string + assert found_initial_sql == expected_initial_sql + + q = q.cte("cte") + q = sqlalchemy.select(*q.columns) + + expected_cte_sql = ( + "WITH `cte` AS \n" + "(SELECT `table1`.`foo` AS `foo`, `bar` \n" + "FROM `table1`, unnest(`table1`.`bars`) AS `bar`)\n" + " SELECT `cte`.`foo`, `cte`.`bar` \n" + "FROM `cte`" + ) + found_cte_sql = q.compile(faux_conn).string + assert found_cte_sql == expected_cte_sql