Skip to content

Commit

Permalink
ENH: to_sql() add parameter "method" to control insertions method (pa…
Browse files Browse the repository at this point in the history
…ndas-dev#8… (pandas-dev#21401)

* ENH: to_sql() add parameter "method" to control insertions method (pandas-dev#8953)

* ENH: to_sql() add parameter "method". Fix docstrings (pandas-dev#8953)

* ENH: to_sql() add parameter "method". Improve docs based on reviews (pandas-dev#8953)

* ENH: to_sql() add parameter "method". Fix unit-test (pandas-dev#8953)

* doc clean-up

* additional doc clean-up

* use dict(zip()) directly

* clean up merge

* default --> None

* Remove stray default

* Remove method kwarg

* change default to None

* test copy insert snippit

* print debug

* index=False

* Add reference to documentation
  • Loading branch information
schettino72 authored and Pingviinituutti committed Feb 28, 2019
1 parent 0277ee7 commit 31d2736
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 18 deletions.
48 changes: 48 additions & 0 deletions doc/source/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4989,6 +4989,54 @@ with respect to the timezone.
timezone aware or naive. When reading ``TIMESTAMP WITH TIME ZONE`` types, pandas
will convert the data to UTC.

.. _io.sql.method:

Insertion Method
++++++++++++++++

.. versionadded:: 0.24.0

The parameter ``method`` controls the SQL insertion clause used.
Possible values are:

- ``None``: Uses standard SQL ``INSERT`` clause (one per row).
- ``'multi'``: Pass multiple values in a single ``INSERT`` clause.
It uses a *special* SQL syntax not supported by all backends.
This usually provides better performance for analytic databases
like *Presto* and *Redshift*, but has worse performance for
traditional SQL backend if the table contains many columns.
For more information check the SQLAlchemy `documention
<http://docs.sqlalchemy.org/en/latest/core/dml.html#sqlalchemy.sql.expression.Insert.values.params.*args>`__.
- callable with signature ``(pd_table, conn, keys, data_iter)``:
This can be used to implement a more performant insertion method based on
specific backend dialect features.

Example of a callable using PostgreSQL `COPY clause
<https://www.postgresql.org/docs/current/static/sql-copy.html>`__::

# Alternative to_sql() *method* for DBs that support COPY FROM
import csv
from io import StringIO

def psql_insert_copy(table, conn, keys, data_iter):
# gets a DBAPI connection that can provide a cursor
dbapi_conn = conn.connection
with dbapi_conn.cursor() as cur:
s_buf = StringIO()
writer = csv.writer(s_buf)
writer.writerows(data_iter)
s_buf.seek(0)

columns = ', '.join('"{}"'.format(k) for k in keys)
if table.schema:
table_name = '{}.{}'.format(table.schema, table.name)
else:
table_name = table.name

sql = 'COPY {} ({}) FROM STDIN WITH CSV'.format(
table_name, columns)
cur.copy_expert(sql=sql, file=s_buf)

Reading Tables
''''''''''''''

Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.24.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ Other Enhancements
- :meth:`DataFrame.between_time` and :meth:`DataFrame.at_time` have gained the ``axis`` parameter (:issue:`8839`)
- The ``scatter_matrix``, ``andrews_curves``, ``parallel_coordinates``, ``lag_plot``, ``autocorrelation_plot``, ``bootstrap_plot``, and ``radviz`` plots from the ``pandas.plotting`` module are now accessible from calling :meth:`DataFrame.plot` (:issue:`11978`)
- :class:`IntervalIndex` has gained the :attr:`~IntervalIndex.is_overlapping` attribute to indicate if the ``IntervalIndex`` contains any overlapping intervals (:issue:`23309`)
- :func:`pandas.DataFrame.to_sql` has gained the ``method`` argument to control SQL insertion clause. See the :ref:`insertion method <io.sql.method>` section in the documentation. (:issue:`8953`)

.. _whatsnew_0240.api_breaking:

Expand Down
15 changes: 13 additions & 2 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2386,7 +2386,7 @@ def to_msgpack(self, path_or_buf=None, encoding='utf-8', **kwargs):
**kwargs)

def to_sql(self, name, con, schema=None, if_exists='fail', index=True,
index_label=None, chunksize=None, dtype=None):
index_label=None, chunksize=None, dtype=None, method=None):
"""
Write records stored in a DataFrame to a SQL database.
Expand Down Expand Up @@ -2424,6 +2424,17 @@ def to_sql(self, name, con, schema=None, if_exists='fail', index=True,
Specifying the datatype for columns. The keys should be the column
names and the values should be the SQLAlchemy types or strings for
the sqlite3 legacy mode.
method : {None, 'multi', callable}, default None
Controls the SQL insertion clause used:
* None : Uses standard SQL ``INSERT`` clause (one per row).
* 'multi': Pass multiple values in a single ``INSERT`` clause.
* callable with signature ``(pd_table, conn, keys, data_iter)``.
Details and a sample callable implementation can be found in the
section :ref:`insert method <io.sql.method>`.
.. versionadded:: 0.24.0
Raises
------
Expand Down Expand Up @@ -2505,7 +2516,7 @@ def to_sql(self, name, con, schema=None, if_exists='fail', index=True,
from pandas.io import sql
sql.to_sql(self, name, con, schema=schema, if_exists=if_exists,
index=index, index_label=index_label, chunksize=chunksize,
dtype=dtype)
dtype=dtype, method=method)

def to_pickle(self, path, compression='infer',
protocol=pkl.HIGHEST_PROTOCOL):
Expand Down
88 changes: 75 additions & 13 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from contextlib import contextmanager
from datetime import date, datetime, time
from functools import partial
import re
import warnings

Expand Down Expand Up @@ -395,7 +396,7 @@ def read_sql(sql, con, index_col=None, coerce_float=True, params=None,


def to_sql(frame, name, con, schema=None, if_exists='fail', index=True,
index_label=None, chunksize=None, dtype=None):
index_label=None, chunksize=None, dtype=None, method=None):
"""
Write records stored in a DataFrame to a SQL database.
Expand Down Expand Up @@ -429,6 +430,17 @@ def to_sql(frame, name, con, schema=None, if_exists='fail', index=True,
Optional specifying the datatype for columns. The SQL type should
be a SQLAlchemy type, or a string for sqlite3 fallback connection.
If all columns are of the same type, one single value can be used.
method : {None, 'multi', callable}, default None
Controls the SQL insertion clause used:
- None : Uses standard SQL ``INSERT`` clause (one per row).
- 'multi': Pass multiple values in a single ``INSERT`` clause.
- callable with signature ``(pd_table, conn, keys, data_iter)``.
Details and a sample callable implementation can be found in the
section :ref:`insert method <io.sql.method>`.
.. versionadded:: 0.24.0
"""
if if_exists not in ('fail', 'replace', 'append'):
raise ValueError("'{0}' is not valid for if_exists".format(if_exists))
Expand All @@ -443,7 +455,7 @@ def to_sql(frame, name, con, schema=None, if_exists='fail', index=True,

pandas_sql.to_sql(frame, name, if_exists=if_exists, index=index,
index_label=index_label, schema=schema,
chunksize=chunksize, dtype=dtype)
chunksize=chunksize, dtype=dtype, method=method)


def has_table(table_name, con, schema=None):
Expand Down Expand Up @@ -568,8 +580,29 @@ def create(self):
else:
self._execute_create()

def insert_statement(self):
return self.table.insert()
def _execute_insert(self, conn, keys, data_iter):
"""Execute SQL statement inserting data
Parameters
----------
conn : sqlalchemy.engine.Engine or sqlalchemy.engine.Connection
keys : list of str
Column names
data_iter : generator of list
Each item contains a list of values to be inserted
"""
data = [dict(zip(keys, row)) for row in data_iter]
conn.execute(self.table.insert(), data)

def _execute_insert_multi(self, conn, keys, data_iter):
"""Alternative to _execute_insert for DBs support multivalue INSERT.
Note: multi-value insert is usually faster for analytics DBs
and tables containing a few columns
but performance degrades quickly with increase of columns.
"""
data = [dict(zip(keys, row)) for row in data_iter]
conn.execute(self.table.insert(data))

def insert_data(self):
if self.index is not None:
Expand Down Expand Up @@ -612,11 +645,18 @@ def insert_data(self):

return column_names, data_list

def _execute_insert(self, conn, keys, data_iter):
data = [dict(zip(keys, row)) for row in data_iter]
conn.execute(self.insert_statement(), data)
def insert(self, chunksize=None, method=None):

# set insert method
if method is None:
exec_insert = self._execute_insert
elif method == 'multi':
exec_insert = self._execute_insert_multi
elif callable(method):
exec_insert = partial(method, self)
else:
raise ValueError('Invalid parameter `method`: {}'.format(method))

def insert(self, chunksize=None):
keys, data_list = self.insert_data()

nrows = len(self.frame)
Expand All @@ -639,7 +679,7 @@ def insert(self, chunksize=None):
break

chunk_iter = zip(*[arr[start_i:end_i] for arr in data_list])
self._execute_insert(conn, keys, chunk_iter)
exec_insert(conn, keys, chunk_iter)

def _query_iterator(self, result, chunksize, columns, coerce_float=True,
parse_dates=None):
Expand Down Expand Up @@ -1085,7 +1125,8 @@ def read_query(self, sql, index_col=None, coerce_float=True,
read_sql = read_query

def to_sql(self, frame, name, if_exists='fail', index=True,
index_label=None, schema=None, chunksize=None, dtype=None):
index_label=None, schema=None, chunksize=None, dtype=None,
method=None):
"""
Write records stored in a DataFrame to a SQL database.
Expand Down Expand Up @@ -1115,7 +1156,17 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
Optional specifying the datatype for columns. The SQL type should
be a SQLAlchemy type. If all columns are of the same type, one
single value can be used.
method : {None', 'multi', callable}, default None
Controls the SQL insertion clause used:
* None : Uses standard SQL ``INSERT`` clause (one per row).
* 'multi': Pass multiple values in a single ``INSERT`` clause.
* callable with signature ``(pd_table, conn, keys, data_iter)``.
Details and a sample callable implementation can be found in the
section :ref:`insert method <io.sql.method>`.
.. versionadded:: 0.24.0
"""
if dtype and not is_dict_like(dtype):
dtype = {col_name: dtype for col_name in frame}
Expand All @@ -1131,7 +1182,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
if_exists=if_exists, index_label=index_label,
schema=schema, dtype=dtype)
table.create()
table.insert(chunksize)
table.insert(chunksize, method=method)
if (not name.isdigit() and not name.islower()):
# check for potentially case sensitivity issues (GH7815)
# Only check when name is not a number and name is not lower case
Expand Down Expand Up @@ -1442,7 +1493,8 @@ def _fetchall_as_list(self, cur):
return result

def to_sql(self, frame, name, if_exists='fail', index=True,
index_label=None, schema=None, chunksize=None, dtype=None):
index_label=None, schema=None, chunksize=None, dtype=None,
method=None):
"""
Write records stored in a DataFrame to a SQL database.
Expand Down Expand Up @@ -1471,7 +1523,17 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
Optional specifying the datatype for columns. The SQL type should
be a string. If all columns are of the same type, one single value
can be used.
method : {None, 'multi', callable}, default None
Controls the SQL insertion clause used:
* None : Uses standard SQL ``INSERT`` clause (one per row).
* 'multi': Pass multiple values in a single ``INSERT`` clause.
* callable with signature ``(pd_table, conn, keys, data_iter)``.
Details and a sample callable implementation can be found in the
section :ref:`insert method <io.sql.method>`.
.. versionadded:: 0.24.0
"""
if dtype and not is_dict_like(dtype):
dtype = {col_name: dtype for col_name in frame}
Expand All @@ -1486,7 +1548,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
if_exists=if_exists, index_label=index_label,
dtype=dtype)
table.create()
table.insert(chunksize)
table.insert(chunksize, method)

def has_table(self, name, schema=None):
# TODO(wesm): unused?
Expand Down
65 changes: 62 additions & 3 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,12 +375,16 @@ def _read_sql_iris_named_parameter(self):
iris_frame = self.pandasSQL.read_query(query, params=params)
self._check_iris_loaded_frame(iris_frame)

def _to_sql(self):
def _to_sql(self, method=None):
self.drop_table('test_frame1')

self.pandasSQL.to_sql(self.test_frame1, 'test_frame1')
self.pandasSQL.to_sql(self.test_frame1, 'test_frame1', method=method)
assert self.pandasSQL.has_table('test_frame1')

num_entries = len(self.test_frame1)
num_rows = self._count_rows('test_frame1')
assert num_rows == num_entries

# Nuke table
self.drop_table('test_frame1')

Expand Down Expand Up @@ -434,6 +438,25 @@ def _to_sql_append(self):
assert num_rows == num_entries
self.drop_table('test_frame1')

def _to_sql_method_callable(self):
check = [] # used to double check function below is really being used

def sample(pd_table, conn, keys, data_iter):
check.append(1)
data = [dict(zip(keys, row)) for row in data_iter]
conn.execute(pd_table.table.insert(), data)
self.drop_table('test_frame1')

self.pandasSQL.to_sql(self.test_frame1, 'test_frame1', method=sample)
assert self.pandasSQL.has_table('test_frame1')

assert check == [1]
num_entries = len(self.test_frame1)
num_rows = self._count_rows('test_frame1')
assert num_rows == num_entries
# Nuke table
self.drop_table('test_frame1')

def _roundtrip(self):
self.drop_table('test_frame_roundtrip')
self.pandasSQL.to_sql(self.test_frame1, 'test_frame_roundtrip')
Expand Down Expand Up @@ -1193,7 +1216,7 @@ def setup_connect(self):
pytest.skip(
"Can't connect to {0} server".format(self.flavor))

def test_aread_sql(self):
def test_read_sql(self):
self._read_sql_iris()

def test_read_sql_parameter(self):
Expand All @@ -1217,6 +1240,12 @@ def test_to_sql_replace(self):
def test_to_sql_append(self):
self._to_sql_append()

def test_to_sql_method_multi(self):
self._to_sql(method='multi')

def test_to_sql_method_callable(self):
self._to_sql_method_callable()

def test_create_table(self):
temp_conn = self.connect()
temp_frame = DataFrame(
Expand Down Expand Up @@ -1930,6 +1959,36 @@ def test_schema_support(self):
res2 = pdsql.read_table('test_schema_other2')
tm.assert_frame_equal(res1, res2)

def test_copy_from_callable_insertion_method(self):
# GH 8953
# Example in io.rst found under _io.sql.method
# not available in sqlite, mysql
def psql_insert_copy(table, conn, keys, data_iter):
# gets a DBAPI connection that can provide a cursor
dbapi_conn = conn.connection
with dbapi_conn.cursor() as cur:
s_buf = compat.StringIO()
writer = csv.writer(s_buf)
writer.writerows(data_iter)
s_buf.seek(0)

columns = ', '.join('"{}"'.format(k) for k in keys)
if table.schema:
table_name = '{}.{}'.format(table.schema, table.name)
else:
table_name = table.name

sql_query = 'COPY {} ({}) FROM STDIN WITH CSV'.format(
table_name, columns)
cur.copy_expert(sql=sql_query, file=s_buf)

expected = DataFrame({'col1': [1, 2], 'col2': [0.1, 0.2],
'col3': ['a', 'n']})
expected.to_sql('test_copy_insert', self.conn, index=False,
method=psql_insert_copy)
result = sql.read_sql_table('test_copy_insert', self.conn)
tm.assert_frame_equal(result, expected)


@pytest.mark.single
class TestMySQLAlchemy(_TestMySQLAlchemy, _TestSQLAlchemy):
Expand Down

0 comments on commit 31d2736

Please sign in to comment.