feat: support CSV format in load_table_from_dataframe pandas connector #399

Merged
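This change lets load_table_from_dataframe serialize the DataFrame to CSV instead of the default Parquet when the supplied LoadJobConfig sets source_format to CSV. A minimal usage sketch of the new path, mirroring the unit test below — the table ID "my-project.my_dataset.my_table" is a placeholder:

# Illustrative sketch only; the table ID is a placeholder, the schema mirrors the unit test below.
import pandas
from google.cloud import bigquery

client = bigquery.Client()
dataframe = pandas.DataFrame([{"id": 1, "age": 100}, {"id": 2, "age": 60}])

job_config = bigquery.LoadJobConfig(
    schema=[
        bigquery.SchemaField("id", "INTEGER"),
        bigquery.SchemaField("age", "INTEGER"),
    ],
    # New in this PR: CSV is accepted alongside the PARQUET default.
    source_format=bigquery.SourceFormat.CSV,
)

load_job = client.load_table_from_dataframe(
    dataframe, "my-project.my_dataset.my_table", job_config=job_config
)
load_job.result()  # block until the load job completes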
101 changes: 67 additions & 34 deletions google/cloud/bigquery/client.py
@@ -2197,15 +2197,16 @@ def load_table_from_dataframe(
else:
job_config = job.LoadJobConfig()

if job_config.source_format:
if job_config.source_format != job.SourceFormat.PARQUET:
raise ValueError(
"Got unexpected source_format: '{}'. Currently, only PARQUET is supported".format(
job_config.source_format
)
)
else:
supported_formats = (job.SourceFormat.CSV, job.SourceFormat.PARQUET)
if job_config.source_format is None:
# default value
job_config.source_format = job.SourceFormat.PARQUET
if job_config.source_format not in supported_formats:
raise ValueError(
"Got unexpected source_format: '{}'. Currently, only PARQUET and CSV are supported".format(
job_config.source_format
)
)

if location is None:
location = self.location
@@ -2245,39 +2246,71 @@ def load_table_from_dataframe(
stacklevel=2,
)

tmpfd, tmppath = tempfile.mkstemp(suffix="_job_{}.parquet".format(job_id[:8]))
tmpfd, tmppath = tempfile.mkstemp(
suffix="_job_{}.{}".format(job_id[:8], job_config.source_format.lower(),)
)
os.close(tmpfd)

try:
if job_config.schema:
if parquet_compression == "snappy": # adjust the default value
parquet_compression = parquet_compression.upper()

_pandas_helpers.dataframe_to_parquet(
dataframe,
job_config.schema,
tmppath,
parquet_compression=parquet_compression,
)
if job_config.source_format == job.SourceFormat.PARQUET:

if job_config.schema:
if parquet_compression == "snappy": # adjust the default value
parquet_compression = parquet_compression.upper()

_pandas_helpers.dataframe_to_parquet(
dataframe,
job_config.schema,
tmppath,
parquet_compression=parquet_compression,
)
else:
dataframe.to_parquet(tmppath, compression=parquet_compression)

with open(tmppath, "rb") as parquet_file:
file_size = os.path.getsize(tmppath)
return self.load_table_from_file(
parquet_file,
destination,
num_retries=num_retries,
rewind=True,
size=file_size,
job_id=job_id,
job_id_prefix=job_id_prefix,
location=location,
project=project,
job_config=job_config,
timeout=timeout,
)

else:

dataframe.to_csv(
tmppath,
index=False,
header=False,
encoding="utf-8",
float_format="%.17g",
date_format="%Y-%m-%d %H:%M:%S.%f",
)

with open(tmppath, "rb") as csv_file:
file_size = os.path.getsize(tmppath)
return self.load_table_from_file(
csv_file,
destination,
num_retries=num_retries,
rewind=True,
size=file_size,
job_id=job_id,
job_id_prefix=job_id_prefix,
location=location,
project=project,
job_config=job_config,
timeout=timeout,
)

finally:
os.remove(tmppath)

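A note on the to_csv arguments in the client change above: header=False is used because BigQuery's CSV load treats every row as data unless skip_leading_rows is set, date_format="%Y-%m-%d %H:%M:%S.%f" matches the DATETIME/TIMESTAMP literal form BigQuery accepts, and float_format="%.17g" lets float64 values round-trip losslessly through text. A tiny sketch of that last point, independent of this library:

# Why "%.17g": 17 significant digits round-trip any IEEE-754 double exactly.
value = 0.1 + 0.2                       # 0.30000000000000004
assert float("%.17g" % value) == value  # lossless at 17 digits
assert float("%.15g" % value) != value  # 15 digits would silently change the value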
97 changes: 97 additions & 0 deletions tests/system.py
@@ -1165,6 +1165,103 @@ def test_load_table_from_json_basic_use(self):
self.assertEqual(tuple(table.schema), table_schema)
self.assertEqual(table.num_rows, 2)

@unittest.skipIf(pandas is None, "Requires `pandas`")
def test_load_table_from_dataframe_w_explicit_schema_source_format_csv(self):
from google.cloud.bigquery.job import SourceFormat

scalars_schema = (
bigquery.SchemaField("bool_col", "BOOLEAN"),
bigquery.SchemaField("bytes_col", "BYTES"),
bigquery.SchemaField("date_col", "DATE"),
bigquery.SchemaField("dt_col", "DATETIME"),
bigquery.SchemaField("float_col", "FLOAT"),
bigquery.SchemaField("geo_col", "GEOGRAPHY"),
bigquery.SchemaField("int_col", "INTEGER"),
bigquery.SchemaField("num_col", "NUMERIC"),
bigquery.SchemaField("str_col", "STRING"),
bigquery.SchemaField("time_col", "TIME"),
bigquery.SchemaField("ts_col", "TIMESTAMP"),
)
table_schema = scalars_schema + (
# TODO: Array columns can't be read due to NULLABLE versus REPEATED
# mode mismatch. See:
# https://issuetracker.google.com/133415569#comment3
# bigquery.SchemaField("array_col", "INTEGER", mode="REPEATED"),
# TODO: Support writing StructArrays to Parquet. See:
# https://jira.apache.org/jira/browse/ARROW-2587
# bigquery.SchemaField("struct_col", "RECORD", fields=scalars_schema),
)
df_data = collections.OrderedDict(
[
("bool_col", [True, None, False]),
("bytes_col", ["abc", None, "def"]),
(
"date_col",
[datetime.date(1, 1, 1), None, datetime.date(9999, 12, 31)],
),
(
"dt_col",
[
datetime.datetime(1, 1, 1, 0, 0, 0),
None,
datetime.datetime(9999, 12, 31, 23, 59, 59, 999999),
],
),
("float_col", [float("-inf"), float("nan"), float("inf")]),
(
"geo_col",
[
"POINT(30 10)",
None,
"POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))",
],
),
("int_col", [-9223372036854775808, None, 9223372036854775807]),
(
"num_col",
[
decimal.Decimal("-99999999999999999999999999999.999999999"),
None,
decimal.Decimal("99999999999999999999999999999.999999999"),
],
),
("str_col", [u"abc", None, u"def"]),
(
"time_col",
[datetime.time(0, 0, 0), None, datetime.time(23, 59, 59, 999999)],
),
(
"ts_col",
[
datetime.datetime(1, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
None,
datetime.datetime(
9999, 12, 31, 23, 59, 59, 999999, tzinfo=pytz.utc
),
],
),
]
)
dataframe = pandas.DataFrame(df_data, dtype="object", columns=df_data.keys())

dataset_id = _make_dataset_id("bq_load_test")
self.temp_dataset(dataset_id)
table_id = "{}.{}.load_table_from_dataframe_w_explicit_schema_csv".format(
Config.CLIENT.project, dataset_id
)

job_config = bigquery.LoadJobConfig(
schema=table_schema, source_format=SourceFormat.CSV
)
load_job = Config.CLIENT.load_table_from_dataframe(
dataframe, table_id, job_config=job_config
)
load_job.result()

table = Config.CLIENT.get_table(table_id)
self.assertEqual(tuple(table.schema), table_schema)
self.assertEqual(table.num_rows, 3)

def test_load_table_from_json_schema_autodetect(self):
json_rows = [
{"name": "John", "age": 18, "birthday": "2001-10-15", "is_awesome": False},
50 changes: 50 additions & 0 deletions tests/unit/test_client.py
@@ -8366,6 +8366,56 @@ def test_load_table_from_dataframe_w_invaild_job_config(self):
err_msg = str(exc.value)
assert "Expected an instance of LoadJobConfig" in err_msg

@unittest.skipIf(pandas is None, "Requires `pandas`")
def test_load_table_from_dataframe_with_csv_source_format(self):
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
from google.cloud.bigquery import job
from google.cloud.bigquery.schema import SchemaField

client = self._make_client()
records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
dataframe = pandas.DataFrame(records)
job_config = job.LoadJobConfig(
write_disposition=job.WriteDisposition.WRITE_TRUNCATE,
source_format=job.SourceFormat.CSV,
)

get_table_patch = mock.patch(
"google.cloud.bigquery.client.Client.get_table",
autospec=True,
return_value=mock.Mock(
schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
),
)
load_patch = mock.patch(
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)
with load_patch as load_table_from_file, get_table_patch:
client.load_table_from_dataframe(
dataframe, self.TABLE_REF, job_config=job_config
)

load_table_from_file.assert_called_once_with(
client,
mock.ANY,
self.TABLE_REF,
num_retries=_DEFAULT_NUM_RETRIES,
rewind=True,
size=mock.ANY,
job_id=mock.ANY,
job_id_prefix=None,
location=None,
project=None,
job_config=mock.ANY,
timeout=None,
)

sent_file = load_table_from_file.mock_calls[0][1][1]
assert sent_file.closed

sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
assert sent_config.source_format == job.SourceFormat.CSV

def test_load_table_from_json_basic_use(self):
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
from google.cloud.bigquery import job