Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
fix: distinct doesn't work as a column wrapper (#275)
  • Loading branch information
jimfulton committed Aug 23, 2021
1 parent e06bf74 commit ad5baf8
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
15 changes: 13 additions & 2 deletions sqlalchemy_bigquery/base.py
Expand Up @@ -247,7 +247,12 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw):
)

def visit_column(
self, column, add_to_result_map=None, include_table=True, **kwargs
self,
column,
add_to_result_map=None,
include_table=True,
result_map_targets=(),
**kwargs,
):
name = orig_name = column.name
if name is None:
Expand All @@ -258,7 +263,12 @@ def visit_column(
name = self._truncated_identifier("colident", name)

if add_to_result_map is not None:
add_to_result_map(name, orig_name, (column, name, column.key), column.type)
targets = (column, name, column.key) + result_map_targets
if getattr(column, "_tq_label", None):
# _tq_label was added in SQLAlchemy 1.4
targets += (column._tq_label,)

add_to_result_map(name, orig_name, targets, column.type)

if is_literal:
name = self.escape_literal_column(name)
Expand All @@ -271,6 +281,7 @@ def visit_column(
tablename = table.name
if isinstance(tablename, elements._truncated_label):
tablename = self._truncated_identifier("alias", tablename)

return self.preparer.quote(tablename) + "." + name

def visit_label(self, *args, within_group_by=False, **kwargs):
Expand Down
6 changes: 6 additions & 0 deletions tests/system/conftest.py
Expand Up @@ -21,6 +21,7 @@
from typing import List

import pytest
import sqlalchemy

from google.cloud import bigquery
import test_utils.prefixer
Expand Down Expand Up @@ -137,3 +138,8 @@ def cleanup_datasets(bigquery_client: bigquery.Client):
bigquery_client.delete_dataset(
dataset, delete_contents=True, not_found_ok=True
)


@pytest.fixture
def metadata():
return sqlalchemy.MetaData()
34 changes: 34 additions & 0 deletions tests/system/test_sqlalchemy_bigquery.py
Expand Up @@ -691,3 +691,37 @@ def test_has_table(engine, engine_using_test_dataset, bigquery_dataset):
assert engine_using_test_dataset.has_table(f"{bigquery_dataset}.sample") is True

assert engine_using_test_dataset.has_table("sample_alt") is False


def test_distinct_188(engine, bigquery_dataset):
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer
from sqlalchemy.orm import sessionmaker

Base = declarative_base()

class MyTable(Base):
__tablename__ = f"{bigquery_dataset}.test_distinct_188"
id = Column(Integer, primary_key=True)
my_column = Column(Integer)

MyTable.__table__.create(engine)

Session = sessionmaker(bind=engine)
db = Session()
db.add_all([MyTable(id=i, my_column=i % 2) for i in range(9)])
db.commit()

expected = [(0,), (1,)]

assert sorted(db.query(MyTable.my_column).distinct().all()) == expected
assert (
sorted(
db.query(
sqlalchemy.distinct(MyTable.my_column).label("just_a_random_label")
).all()
)
== expected
)

assert sorted(db.query(sqlalchemy.distinct(MyTable.my_column)).all()) == expected

0 comments on commit ad5baf8

Please sign in to comment.