Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support passing arbitrary arguments/context to custom extensions (Issue #700) #814

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 24 additions & 6 deletions src/syrupy/assertion.py
Expand Up @@ -63,6 +63,7 @@ class SnapshotAssertion:
include: Optional["PropertyFilter"] = None
exclude: Optional["PropertyFilter"] = None
matcher: Optional["PropertyMatcher"] = None
extra_args: Dict = field(default_factory=dict)

_exclude: Optional["PropertyFilter"] = field(
init=False,
Expand Down Expand Up @@ -105,6 +106,7 @@ def __post_init__(self) -> None:
self._include = self.include
self._exclude = self.exclude
self._matcher = self.matcher
self._extra_args = self.extra_args

def __init_extension(
self, extension_class: Type["AbstractSyrupyExtension"]
Expand Down Expand Up @@ -178,6 +180,7 @@ def with_defaults(
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
extra_args: Optional[Dict] = None,
) -> "SnapshotAssertion":
"""
Create new snapshot assertion fixture with provided values. This preserves
Expand All @@ -191,6 +194,7 @@ def with_defaults(
test_location=self.test_location,
extension_class=extension_class or self.extension_class,
session=self.session,
extra_args=extra_args or self.extra_args,
)

def use_extension(
Expand All @@ -205,9 +209,13 @@ def use_extension(
def assert_match(self, data: "SerializableData") -> None:
assert self == data

def _serialize(self, data: "SerializableData") -> "SerializedData":
def _serialize(self, data: "SerializableData", **kwargs: Any) -> "SerializedData":
return self.extension.serialize(
data, exclude=self._exclude, include=self._include, matcher=self.__matcher
data,
exclude=self._exclude,
include=self._include,
matcher=self.__matcher,
**kwargs,
)

def get_assert_diff(self) -> List[str]:
Expand Down Expand Up @@ -264,6 +272,7 @@ def __call__(
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
matcher: Optional["PropertyMatcher"] = None,
name: Optional["SnapshotIndex"] = None,
extra_args: Optional[Dict] = None,
) -> "SnapshotAssertion":
"""
Modifies assertion instance options
Expand All @@ -280,6 +289,8 @@ def __call__(
self.__with_prop("_custom_index", name)
if diff is not None:
self.__with_prop("_snapshot_diff", diff)
if extra_args:
self.__with_prop("_extra_args", extra_args)
return self

def __repr__(self) -> str:
Expand All @@ -300,23 +311,29 @@ def _assert(self, data: "SerializableData") -> bool:
matches = False
assertion_success = False
assertion_exception = None
extra_args = getattr(self, "_extra_args", {})
try:
snapshot_data, tainted = self._recall_data(index=self.index)
serialized_data = self._serialize(data)
serialized_data = self._serialize(data, **extra_args)
snapshot_diff = getattr(self, "_snapshot_diff", None)
if snapshot_diff is not None:
snapshot_data_diff, _ = self._recall_data(index=snapshot_diff)
snapshot_data_diff, _ = self._recall_data(
index=snapshot_diff, **extra_args
)
if snapshot_data_diff is None:
raise SnapshotDoesNotExist()
serialized_data = self.extension.diff_snapshots(
serialized_data=serialized_data,
snapshot_data=snapshot_data_diff,
**extra_args,
)
matches = (
not tainted
and snapshot_data is not None
and self.extension.matches(
serialized_data=serialized_data, snapshot_data=snapshot_data
serialized_data=serialized_data,
snapshot_data=snapshot_data,
**extra_args,
)
)
assertion_success = matches
Expand Down Expand Up @@ -361,14 +378,15 @@ def _post_assert(self) -> None:
self._post_assert_actions.pop()()

def _recall_data(
self, index: "SnapshotIndex"
self, index: "SnapshotIndex", **kwargs: Any
) -> Tuple[Optional["SerializableData"], bool]:
try:
return (
self.extension.read_snapshot(
test_location=self.test_location,
index=index,
session_id=str(id(self.session)),
**kwargs,
),
False,
)
Expand Down
6 changes: 4 additions & 2 deletions src/syrupy/extensions/amber/__init__.py
Expand Up @@ -47,7 +47,9 @@ def delete_snapshots(
else:
Path(snapshot_location).unlink()

def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection":
def _read_snapshot_collection(
self, snapshot_location: str, **kwargs: Any
) -> "SnapshotCollection":
return self.serializer_class.read_file(snapshot_location)

@classmethod
Expand All @@ -72,7 +74,7 @@ def _read_snapshot_data_from_location(

@classmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
cls, *, snapshot_collection: "SnapshotCollection", **kwargs: Any
) -> None:
cls.serializer_class.write_file(snapshot_collection, merge=True)

Expand Down
32 changes: 24 additions & 8 deletions src/syrupy/extensions/base.py
Expand Up @@ -8,6 +8,7 @@
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterator,
Expand Down Expand Up @@ -67,6 +68,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs: Any,
) -> "SerializedData":
"""
Serializes a python object / data structure into a string
Expand Down Expand Up @@ -108,7 +110,7 @@ def is_snapshot_location(self, *, location: str) -> bool:
return location.endswith(self._file_extension)

def discover_snapshots(
self, *, test_location: "PyTestLocation"
self, *, test_location: "PyTestLocation", **kwargs: Any
) -> "SnapshotCollections":
"""
Returns all snapshot collections in test site
Expand All @@ -117,7 +119,7 @@ def discover_snapshots(
for filepath in walk_snapshot_dir(self.dirname(test_location=test_location)):
if self.is_snapshot_location(location=filepath):
snapshot_collection = self._read_snapshot_collection(
snapshot_location=filepath
snapshot_location=filepath, **kwargs
)
if not snapshot_collection.has_snapshots:
snapshot_collection = SnapshotEmptyCollection(location=filepath)
Expand All @@ -134,6 +136,7 @@ def read_snapshot(
test_location: "PyTestLocation",
index: "SnapshotIndex",
session_id: str,
**kwargs: Any,
) -> "SerializedData":
"""
This method is _final_, do not override. You can override
Expand All @@ -145,6 +148,7 @@ def read_snapshot(
snapshot_location=snapshot_location,
snapshot_name=snapshot_name,
session_id=session_id,
**kwargs,
)
if snapshot_data is None:
raise SnapshotDoesNotExist()
Expand Down Expand Up @@ -216,7 +220,7 @@ def delete_snapshots(

@abstractmethod
def _read_snapshot_collection(
self, *, snapshot_location: str
self, *, snapshot_location: str, **kwargs: Any
) -> "SnapshotCollection":
"""
Read the snapshot location and construct a snapshot collection object
Expand All @@ -225,7 +229,12 @@ def _read_snapshot_collection(

@abstractmethod
def _read_snapshot_data_from_location(
self, *, snapshot_location: str, snapshot_name: str, session_id: str
self,
*,
snapshot_location: str,
snapshot_name: str,
session_id: str,
**kwargs: Any,
) -> Optional["SerializedData"]:
"""
Get only the snapshot data from location for assertion
Expand All @@ -235,15 +244,15 @@ def _read_snapshot_data_from_location(
@classmethod
@abstractmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
cls, *, snapshot_collection: "SnapshotCollection", **kwargs: Any
) -> None:
"""
Adds the snapshot data to the snapshots in collection location
"""
raise NotImplementedError

@classmethod
def dirname(cls, *, test_location: "PyTestLocation") -> str:
def dirname(cls, *, test_location: "PyTestLocation", **kwargs: Any) -> str:
test_dir = Path(test_location.filepath).parent
return str(test_dir.joinpath(SNAPSHOT_DIRNAME))

Expand All @@ -259,15 +268,21 @@ class SnapshotReporter(ABC):
_context_line_count = 1

def diff_snapshots(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
self,
serialized_data: "SerializedData",
snapshot_data: "SerializedData",
**kwargs: Any,
) -> "SerializedData":
env = {DISABLE_COLOR_ENV_VAR: "true"}
attrs = {"_context_line_count": 0}
with env_context(**env), obj_attrs(self, attrs):
return "\n".join(self.diff_lines(serialized_data, snapshot_data))

def diff_lines(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
self,
serialized_data: "SerializedData",
snapshot_data: "SerializedData",
**kwargs: Any,
) -> Iterator[str]:
for line in self.__diff_lines(str(snapshot_data), str(serialized_data)):
yield reset(line)
Expand Down Expand Up @@ -407,6 +422,7 @@ def matches(
*,
serialized_data: "SerializableData",
snapshot_data: "SerializableData",
**kwargs: Any,
) -> bool:
"""
Compares serialized data and snapshot data and returns
Expand Down
1 change: 1 addition & 0 deletions src/syrupy/extensions/json/__init__.py
Expand Up @@ -145,6 +145,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs: Any,
) -> "SerializedData":
data = self._filter(
data=data,
Expand Down
14 changes: 11 additions & 3 deletions src/syrupy/extensions/single_file.py
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Optional,
Set,
Type,
Expand Down Expand Up @@ -49,6 +50,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs: Any,
) -> "SerializedData":
return self.get_supported_dataclass()(data)

Expand All @@ -74,12 +76,15 @@ def _get_file_basename(
return cls.get_snapshot_name(test_location=test_location, index=index)

@classmethod
def dirname(cls, *, test_location: "PyTestLocation") -> str:
def dirname(cls, *, test_location: "PyTestLocation", **kwargs: Any) -> str:
original_dirname = AbstractSyrupyExtension.dirname(test_location=test_location)
return str(Path(original_dirname).joinpath(test_location.basename))

def _read_snapshot_collection(
self, *, snapshot_location: str
self,
*,
snapshot_location: str,
**kwargs: Any,
) -> "SnapshotCollection":
file_ext_len = len(self._file_extension) + 1 if self._file_extension else 0
filename_wo_ext = snapshot_location[:-file_ext_len]
Expand Down Expand Up @@ -116,7 +121,10 @@ def get_write_encoding(cls) -> Optional[str]:

@classmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
cls,
*,
snapshot_collection: "SnapshotCollection",
**kwargs: Any,
) -> None:
filepath, data = (
snapshot_collection.location,
Expand Down