Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upsert #156

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
12 changes: 11 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,17 @@ Keyword Argument Description
parameters.

``ignore_conflicts`` Specify True to ignore unique constraint or exclusion
constraint violation errors. The default is False.
constraint violation errors. The default is False. This
is deprecated in favor of `on_conflict={'action': 'ignore'}`.

``on_conflict`` Specifies how PostgreSQL handles conflicts. For example,
`on_conflict={'action': 'ignore'}` will ignore any
conflicts. If setting `'action'` to `'update'`, you
must also specify `'target'` (the source of the
constraint: either a model field name, a constraint name,
or a list of model field names) as well as `'columns'`
(a list of model fields to update). The default is None,
which will raise conflict errors if they occur.

``using`` Sets the database to use when importing data.
Default is None, which will use the ``'default'``
Expand Down
87 changes: 79 additions & 8 deletions postgres_copy/copy_from.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
from collections import OrderedDict
from io import TextIOWrapper
import warnings
from django.db import NotSupportedError
from django.db import connections, router
from django.core.exceptions import FieldDoesNotExist
Expand All @@ -33,6 +34,7 @@ def __init__(
force_null=None,
encoding=None,
ignore_conflicts=False,
on_conflict={},
static_mapping=None,
temp_table_name=None
):
Expand All @@ -57,8 +59,9 @@ def __init__(
self.force_not_null = force_not_null
self.force_null = force_null
self.encoding = encoding
self.supports_ignore_conflicts = True
self.supports_on_conflict = True
self.ignore_conflicts = ignore_conflicts
self.on_conflict = on_conflict
if static_mapping is not None:
self.static_mapping = OrderedDict(static_mapping)
else:
Expand All @@ -76,10 +79,18 @@ def __init__(
if self.conn.vendor != 'postgresql':
raise TypeError("Only PostgreSQL backends supported")

# Check if it is PSQL 9.5 or greater, which determines if ignore_conflicts is supported
self.supports_ignore_conflicts = self.is_postgresql_9_5()
if self.ignore_conflicts and not self.supports_ignore_conflicts:
raise NotSupportedError('This database backend does not support ignoring conflicts.')
# Check if it is PSQL 9.5 or greater, which determines if on_conflict is supported
self.supports_on_conflict = self.is_postgresql_9_5()
if self.ignore_conflicts:
self.on_conflict = {
'action': 'ignore',
}
warnings.warn(
"The `ignore_conflicts` kwarg has been replaced with "
"on_conflict={'action': 'ignore'}."
)
if self.on_conflict and not self.supports_on_conflict:
raise NotSupportedError('This database backend does not support conflict logic.')

# Pull the CSV headers
self.headers = self.get_headers()
Expand Down Expand Up @@ -317,10 +328,70 @@ def insert_suffix(self):
"""
Preps the suffix to the insert query.
"""
if self.ignore_conflicts:
if self.on_conflict:
try:
action = self.on_conflict['action']
except KeyError:
raise ValueError("Must specify an `action` when passing `on_conflict`.")
if action == 'ignore':
target, action = "", "DO NOTHING"
elif action == 'update':
try:
target = self.on_conflict['target']
except KeyError:
raise ValueError("Must specify `target` when action == 'update'.")
try:
columns = self.on_conflict['columns']
except KeyError:
raise ValueError("Must specify `columns` when action == 'update'.")

# As recommended in PostgreSQL's INSERT documentation, we use "index inference"
# rather than naming a constraint directly. Another reason to infer the index
# from its columns: when an `include` param is passed to
# django.db.models.UniqueConstraint, Django creates a UNIQUE INDEX instead of a CONSTRAINT.
constraints = {c.name: c for c in self.model._meta.constraints}
if isinstance(target, str):
if constraint := constraints.get(target):
# Make sure to use db column names
target = [
self.get_field(field_name).column
for field_name in constraint.fields
]
else:
target = [target]
elif not isinstance(target, list):
raise ValueError("`target` must be a string or a list.")
target = "({0})".format(', '.join(target))

# Convert to db_column names
db_columns = [self.model._meta.get_field(col).column for col in columns]

# Get update_values from the `excluded` table
update_values = ', '.join([
"{0} = excluded.{0}".format(db_col)
for db_col in db_columns
])

# Only update the row if the values are different
model_table = self.model._meta.db_table
new_values = ', '.join([
model_table + '.' + db_col
for db_col in db_columns
])
old_values = ', '.join([
"excluded.{0}".format(db_col)
for db_col in db_columns
])
action = "DO UPDATE SET {0} WHERE ({1}) IS DISTINCT FROM ({2})".format(
update_values,
new_values,
old_values,
)
else:
raise ValueError("Action must be one of 'ignore' or 'update'.")
return """
ON CONFLICT DO NOTHING;
"""
ON CONFLICT {0} {1};
""".format(target, action)
else:
return ";"

Expand Down
45 changes: 39 additions & 6 deletions postgres_copy/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,32 +57,44 @@ def drop_constraints(self):

# Remove any field constraints
for field in self.constrained_fields:
logger.debug("Dropping constraints from {}".format(field))
logger.debug("Dropping field constraint from {}".format(field))
field_copy = field.__copy__()
field_copy.db_constraint = False
args = (self.model, field, field_copy)
self.edit_schema(schema_editor, 'alter_field', args)

# Remove remaining constraints
for constraint in getattr(self.model._meta, 'constraints', []):
logger.debug("Dropping constraint '{}'".format(constraint.name))
args = (self.model, constraint)
self.edit_schema(schema_editor, 'remove_constraint', args)

def drop_indexes(self):
"""
Drop indexes on the model and its fields.
"""
logger.debug("Dropping indexes from {}".format(self.model.__name__))
with connection.schema_editor() as schema_editor:
# Remove any "index_together" constraints
logger.debug("Dropping index_together of {}".format(self.model._meta.index_together))
if self.model._meta.index_together:
logger.debug("Dropping index_together of {}".format(self.model._meta.index_together))
args = (self.model, self.model._meta.index_together, ())
self.edit_schema(schema_editor, 'alter_index_together', args)

# Remove any field indexes
for field in self.indexed_fields:
logger.debug("Dropping index from {}".format(field))
logger.debug("Dropping field index from {}".format(field))
field_copy = field.__copy__()
field_copy.db_index = False
args = (self.model, field, field_copy)
self.edit_schema(schema_editor, 'alter_field', args)

# Remove remaining indexes
for index in getattr(self.model._meta, 'indexes', []):
logger.debug("Dropping index '{}'".format(index.name))
args = (self.model, index)
self.edit_schema(schema_editor, 'remove_index', args)

def restore_constraints(self):
"""
Restore constraints on the model and its fields.
Expand All @@ -95,14 +107,20 @@ def restore_constraints(self):
args = (self.model, (), self.model._meta.unique_together)
self.edit_schema(schema_editor, 'alter_unique_together', args)

# Add any constraints to the fields
# Add any field constraints
for field in self.constrained_fields:
logger.debug("Adding constraints to {}".format(field))
logger.debug("Adding field constraint to {}".format(field))
field_copy = field.__copy__()
field_copy.db_constraint = False
args = (self.model, field_copy, field)
self.edit_schema(schema_editor, 'alter_field', args)

# Add remaining constraints
for constraint in getattr(self.model._meta, 'constraints', []):
logger.debug("Adding constraint '{}'".format(constraint.name))
args = (self.model, constraint)
self.edit_schema(schema_editor, 'add_constraint', args)

def restore_indexes(self):
"""
Restore indexes on the model and its fields.
Expand All @@ -117,12 +135,18 @@ def restore_indexes(self):

# Add any indexes to the fields
for field in self.indexed_fields:
logger.debug("Restoring index to {}".format(field))
logger.debug("Restoring field index to {}".format(field))
field_copy = field.__copy__()
field_copy.db_index = False
args = (self.model, field_copy, field)
self.edit_schema(schema_editor, 'alter_field', args)

# Add remaining indexes
for index in getattr(self.model._meta, 'indexes', []):
logger.debug("Adding index '{}'".format(index.name))
args = (self.model, index)
self.edit_schema(schema_editor, 'add_index', args)


class CopyQuerySet(ConstraintQuerySet):
"""
Expand All @@ -146,6 +170,15 @@ def from_csv(self, csv_path, mapping=None, drop_constraints=True, drop_indexes=T
"anyway. Either remove the transaction block, or set "
"drop_constraints=False and drop_indexes=False.")

# NOTE: See GH Issue #117
# We could remove this block if drop_constraints' default was False
if on_conflict := kwargs.get('on_conflict'):
if target := on_conflict.get('target'):
if target in [c.name for c in self.model._meta.constraints]:
drop_constraints = False
elif on_conflict.get('action') == 'ignore':
drop_constraints = False

mapping = CopyMapping(self.model, csv_path, mapping, **kwargs)

if drop_constraints:
Expand Down
31 changes: 30 additions & 1 deletion tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,35 @@ class SecondaryMockObject(models.Model):
objects = CopyManager()


class UniqueMockObject(models.Model):
class UniqueFieldConstraintMockObject(models.Model):
name = models.CharField(max_length=500, unique=True)
objects = CopyManager()


class UniqueModelConstraintMockObject(models.Model):
    # Mock model whose uniqueness on ``name`` is declared through a
    # model-level UniqueConstraint (Meta.constraints) rather than
    # ``unique=True`` on the field itself.
    name = models.CharField(max_length=500)
    # The database column is named 'num', not 'number'.
    number = MyIntegerField(null=True, db_column='num')
    objects = CopyManager()

    class Meta:
        constraints = [
            models.UniqueConstraint(fields=['name'], name='constraint'),
        ]


class UniqueModelConstraintAsIndexMockObject(models.Model):
    # Mock model whose UniqueConstraint carries an ``include`` list; the
    # inline note on the constraint records that this turns the constraint
    # into an index.
    name = models.CharField(max_length=500)
    # The database column is named 'num', not 'number'.
    number = MyIntegerField(null=True, db_column='num')
    objects = CopyManager()

    class Meta:
        constraints = [
            models.UniqueConstraint(
                fields=['name'],
                include=['number'],  # Converts Constraint to Index
                name='constraint_as_index',
            ),
        ]
84 changes: 81 additions & 3 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
OverloadMockObject,
HookedCopyMapping,
SecondaryMockObject,
UniqueMockObject
UniqueFieldConstraintMockObject,
UniqueModelConstraintMockObject,
UniqueModelConstraintAsIndexMockObject,
)
from django.test import TestCase
from django.db import transaction
Expand Down Expand Up @@ -589,17 +591,93 @@ def test_encoding_save(self, _):

@mock.patch("django.db.connection.validate_no_atomic_block")
def test_ignore_conflicts(self, _):
UniqueMockObject.objects.from_csv(
UniqueFieldConstraintMockObject.objects.from_csv(
self.name_path,
dict(name='NAME'),
ignore_conflicts=True
)
UniqueMockObject.objects.from_csv(
UniqueFieldConstraintMockObject.objects.from_csv(
self.name_path,
dict(name='NAME'),
ignore_conflicts=True
)

@mock.patch("django.db.connection.validate_no_atomic_block")
def test_on_conflict_ignore(self, _):
    """
    Load the same CSV twice with on_conflict={'action': 'ignore'}.

    The test passes as long as neither import raises.
    """
    # Two identical passes: the second one re-imports the same file.
    for _attempt in range(2):
        UniqueModelConstraintMockObject.objects.from_csv(
            self.name_path,
            dict(name='NAME', number='NUMBER'),
            on_conflict={'action': 'ignore'},
        )

@mock.patch("django.db.connection.validate_no_atomic_block")
def test_on_conflict_target_field_update(self, _):
    """
    Load the same CSV twice with an 'update' action targeting the ``name``
    field of a model that declares ``unique=True`` on that field.

    The test passes as long as neither import raises.
    """
    # Two identical passes: the second one re-imports the same file.
    for _attempt in range(2):
        UniqueFieldConstraintMockObject.objects.from_csv(
            self.name_path,
            dict(name='NAME'),
            on_conflict={
                'action': 'update',
                'target': 'name',
                'columns': ['name'],
            },
        )

@mock.patch("django.db.connection.validate_no_atomic_block")
def test_on_conflict_target_constraint_update(self, _):
    """
    Load the same CSV twice with an 'update' action whose target is given
    as the constraint name 'constraint'.

    The test passes as long as neither import raises.
    """
    # Two identical passes: the second one re-imports the same file.
    for _attempt in range(2):
        UniqueModelConstraintMockObject.objects.from_csv(
            self.name_path,
            dict(name='NAME', number='NUMBER'),
            on_conflict={
                'action': 'update',
                'target': 'constraint',
                'columns': ['name', 'number'],
            },
        )

@mock.patch("django.db.connection.validate_no_atomic_block")
def test_on_conflict_target_constraint_as_index_update(self, _):
    """
    Load the same CSV twice with an 'update' action whose target is given
    as the constraint name 'constraint_as_index'.

    The test passes as long as neither import raises.
    """
    # Two identical passes: the second one re-imports the same file.
    for _attempt in range(2):
        UniqueModelConstraintAsIndexMockObject.objects.from_csv(
            self.name_path,
            dict(name='NAME', number='NUMBER'),
            on_conflict={
                'action': 'update',
                'target': 'constraint_as_index',
                'columns': ['name', 'number'],
            },
        )

@mock.patch("django.db.connection.validate_no_atomic_block")
def test_static_values(self, _):
ExtendedMockObject.objects.from_csv(
Expand Down