Skip to content

Commit

Permalink
Address palewire#91.
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffrey-eisenbarth committed Mar 5, 2023
1 parent 9159b12 commit 4376299
Showing 1 changed file with 45 additions and 9 deletions.
54 changes: 45 additions & 9 deletions postgres_copy/copy_from.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
from collections import OrderedDict
from io import TextIOWrapper
import warnings
from django.db import NotSupportedError
from django.db import connections, router
from django.core.exceptions import FieldDoesNotExist
Expand All @@ -33,6 +34,7 @@ def __init__(
force_null=None,
encoding=None,
ignore_conflicts=False,
on_conflict={},
static_mapping=None,
temp_table_name=None
):
Expand All @@ -57,8 +59,9 @@ def __init__(
self.force_not_null = force_not_null
self.force_null = force_null
self.encoding = encoding
self.supports_ignore_conflicts = True
self.supports_on_conflict = True
self.ignore_conflicts = ignore_conflicts
self.on_conflict = on_conflict
if static_mapping is not None:
self.static_mapping = OrderedDict(static_mapping)
else:
Expand All @@ -76,10 +79,18 @@ def __init__(
if self.conn.vendor != 'postgresql':
raise TypeError("Only PostgreSQL backends supported")

# Check if it is PSQL 9.5 or greater, which determines if ignore_conflicts is supported
self.supports_ignore_conflicts = self.is_postgresql_9_5()
if self.ignore_conflicts and not self.supports_ignore_conflicts:
raise NotSupportedError('This database backend does not support ignoring conflicts.')
# Check if it is PSQL 9.5 or greater, which determines if on_conflict is supported
self.supports_on_conflict = self.is_postgresql_9_5()
if self.ignore_conflicts:
self.on_conflict = {
'action': 'ignore',
}
warnings.warn(
"The `ignore_conflicts` kwarg has been replaced with "
"on_conflict={'action': 'ignore'}."
)
if self.on_conflict and not self.supports_on_conflict:
raise NotSupportedError('This database backend does not support conflict logic.')

# Pull the CSV headers
self.headers = self.get_headers()
Expand Down Expand Up @@ -317,10 +328,35 @@ def insert_suffix(self):
"""
Preps the suffix to the insert query.
"""
if self.ignore_conflicts:
return """
ON CONFLICT DO NOTHING;
"""
if self.on_conflict:
try:
action = self.on_conflict['action']
except KeyError:
raise ValueError("Must specify an `action` when passing `on_conflict`.")
if action is None:
target, action = "", "DO NOTHING"
elif action == 'update':
try:
target = self.on_conflict['target']
except KeyError:
raise ValueError("Must specify `target` when action == 'update'.")
if target in [f.name for f in self.model._meta.fields]:
target = "({0})".format(target)
elif target in [c.name for c in self.model._meta.constraints]:
target = "ON CONSTRAINT {0}".format(target)
else:
raise ValueError("`target` must be a field name or constraint name.")

if 'columns' in self.on_conflict:
columns = ', '.join([
"{0} = excluded.{0}".format(col)
for col in self.on_conflict['columns']
])
else:
raise ValueError("Must specify `columns` when action == 'update'.")

action = "DO UPDATE SET {0}".format(columns)
return "ON CONFLICT {0} {1};".format(target, action)
else:
return ";"

Expand Down

0 comments on commit 4376299

Please sign in to comment.