diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 12ab3e36..d1b0a75b 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -247,7 +247,12 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw): ) def visit_column( - self, column, add_to_result_map=None, include_table=True, **kwargs + self, + column, + add_to_result_map=None, + include_table=True, + result_map_targets=(), + **kwargs, ): name = orig_name = column.name if name is None: @@ -258,7 +263,12 @@ def visit_column( name = self._truncated_identifier("colident", name) if add_to_result_map is not None: - add_to_result_map(name, orig_name, (column, name, column.key), column.type) + targets = (column, name, column.key) + result_map_targets + if getattr(column, "_tq_label", None): + # _tq_label was added in SQLAlchemy 1.4 + targets += (column._tq_label,) + + add_to_result_map(name, orig_name, targets, column.type) if is_literal: name = self.escape_literal_column(name) @@ -271,6 +281,7 @@ def visit_column( tablename = table.name if isinstance(tablename, elements._truncated_label): tablename = self._truncated_identifier("alias", tablename) + return self.preparer.quote(tablename) + "." + name def visit_label(self, *args, within_group_by=False, **kwargs): diff --git a/tests/system/conftest.py b/tests/system/conftest.py index 077f06a8..d9db14ab 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -21,6 +21,7 @@ from typing import List import pytest +import sqlalchemy from google.cloud import bigquery import test_utils.prefixer @@ -137,3 +138,8 @@ def cleanup_datasets(bigquery_client: bigquery.Client): bigquery_client.delete_dataset( dataset, delete_contents=True, not_found_ok=True ) + + +@pytest.fixture +def metadata(): + return sqlalchemy.MetaData() diff --git a/tests/system/test_sqlalchemy_bigquery.py b/tests/system/test_sqlalchemy_bigquery.py index 4380b0be..1390d11c 100644 --- a/tests/system/test_sqlalchemy_bigquery.py +++ b/tests/system/test_sqlalchemy_bigquery.py @@ -691,3 +691,37 @@ def test_has_table(engine, engine_using_test_dataset, bigquery_dataset): assert engine_using_test_dataset.has_table(f"{bigquery_dataset}.sample") is True assert engine_using_test_dataset.has_table("sample_alt") is False + + +def test_distinct_188(engine, bigquery_dataset): + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy import Column, Integer + from sqlalchemy.orm import sessionmaker + + Base = declarative_base() + + class MyTable(Base): + __tablename__ = f"{bigquery_dataset}.test_distinct_188" + id = Column(Integer, primary_key=True) + my_column = Column(Integer) + + MyTable.__table__.create(engine) + + Session = sessionmaker(bind=engine) + db = Session() + db.add_all([MyTable(id=i, my_column=i % 2) for i in range(9)]) + db.commit() + + expected = [(0,), (1,)] + + assert sorted(db.query(MyTable.my_column).distinct().all()) == expected + assert ( + sorted( + db.query( + sqlalchemy.distinct(MyTable.my_column).label("just_a_random_label") + ).all() + ) + == expected + ) + + assert sorted(db.query(sqlalchemy.distinct(MyTable.my_column)).all()) == expected