Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support passing arbitrary arguments/context to custom extensions (Issue #700) #814

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/syrupy/assertion.py
Expand Up @@ -63,6 +63,7 @@ class SnapshotAssertion:
include: Optional["PropertyFilter"] = None
exclude: Optional["PropertyFilter"] = None
matcher: Optional["PropertyMatcher"] = None
extra_args: Dict = field(default_factory=dict)

_exclude: Optional["PropertyFilter"] = field(
init=False,
Expand Down Expand Up @@ -105,6 +106,7 @@ def __post_init__(self) -> None:
self._include = self.include
self._exclude = self.exclude
self._matcher = self.matcher
self._extra_args = self.extra_args

def __init_extension(
self, extension_class: Type["AbstractSyrupyExtension"]
Expand Down Expand Up @@ -178,6 +180,7 @@ def with_defaults(
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
extra_args: Optional[Dict] = None
) -> "SnapshotAssertion":
"""
Create new snapshot assertion fixture with provided values. This preserves
Expand All @@ -191,6 +194,7 @@ def with_defaults(
test_location=self.test_location,
extension_class=extension_class or self.extension_class,
session=self.session,
extra_args=extra_args or self.extra_args
)

def use_extension(
Expand Down Expand Up @@ -264,6 +268,7 @@ def __call__(
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
matcher: Optional["PropertyMatcher"] = None,
name: Optional["SnapshotIndex"] = None,
extra_args: Optional[Dict] = None,
) -> "SnapshotAssertion":
"""
Modifies assertion instance options
Expand All @@ -280,6 +285,8 @@ def __call__(
self.__with_prop("_custom_index", name)
if diff is not None:
self.__with_prop("_snapshot_diff", diff)
if extra_args:
self._extra_args = extra_args
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use self.__with_prop. It does seam auto-cleanup

return self

def __repr__(self) -> str:
Expand All @@ -300,6 +307,11 @@ def _assert(self, data: "SerializableData") -> bool:
matches = False
assertion_success = False
assertion_exception = None
matcher_options = None
for key,value in self._extra_args.items():
if key == "matcher_options":
matcher_options = value
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This enforces a certain schema in the extra_args object which if we want to support arbitrary args we shouldn't do. Instead, we should propagate the "extra_args" to all relevant methods.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed this. I also forwaded extra_args to different functions. Could you please tell me what you think about this?

print("matcher_options", matcher_options)
atharva-2001 marked this conversation as resolved.
Show resolved Hide resolved
try:
snapshot_data, tainted = self._recall_data(index=self.index)
serialized_data = self._serialize(data)
Expand All @@ -316,7 +328,7 @@ def _assert(self, data: "SerializableData") -> bool:
not tainted
and snapshot_data is not None
and self.extension.matches(
serialized_data=serialized_data, snapshot_data=snapshot_data
serialized_data=serialized_data, snapshot_data=snapshot_data, **matcher_options
)
)
assertion_success = matches
Expand Down
6 changes: 4 additions & 2 deletions src/syrupy/extensions/amber/__init__.py
Expand Up @@ -47,7 +47,9 @@ def delete_snapshots(
else:
Path(snapshot_location).unlink()

def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection":
def _read_snapshot_collection(
self, snapshot_location: str, **kwargs: Any
) -> "SnapshotCollection":
return self.serializer_class.read_file(snapshot_location)

@classmethod
Expand All @@ -72,7 +74,7 @@ def _read_snapshot_data_from_location(

@classmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
cls, *, snapshot_collection: "SnapshotCollection", **kwargs: Any
) -> None:
cls.serializer_class.write_file(snapshot_collection, merge=True)

Expand Down
21 changes: 15 additions & 6 deletions src/syrupy/extensions/base.py
Expand Up @@ -8,6 +8,7 @@
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterator,
Expand Down Expand Up @@ -67,6 +68,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs: Any,
) -> "SerializedData":
"""
Serializes a python object / data structure into a string
Expand Down Expand Up @@ -108,7 +110,7 @@ def is_snapshot_location(self, *, location: str) -> bool:
return location.endswith(self._file_extension)

def discover_snapshots(
self, *, test_location: "PyTestLocation"
self, *, test_location: "PyTestLocation", **kwargs: Any
) -> "SnapshotCollections":
"""
Returns all snapshot collections in test site
Expand Down Expand Up @@ -216,7 +218,7 @@ def delete_snapshots(

@abstractmethod
def _read_snapshot_collection(
self, *, snapshot_location: str
self, *, snapshot_location: str, **kwargs: Any
) -> "SnapshotCollection":
"""
Read the snapshot location and construct a snapshot collection object
Expand All @@ -235,15 +237,15 @@ def _read_snapshot_data_from_location(
@classmethod
@abstractmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
cls, *, snapshot_collection: "SnapshotCollection", **kwargs: Any
) -> None:
"""
Adds the snapshot data to the snapshots in collection location
"""
raise NotImplementedError

@classmethod
def dirname(cls, *, test_location: "PyTestLocation") -> str:
def dirname(cls, *, test_location: "PyTestLocation", **kwargs: Any) -> str:
test_dir = Path(test_location.filepath).parent
return str(test_dir.joinpath(SNAPSHOT_DIRNAME))

Expand All @@ -259,15 +261,21 @@ class SnapshotReporter(ABC):
_context_line_count = 1

def diff_snapshots(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
self,
serialized_data: "SerializedData",
snapshot_data: "SerializedData",
**kwargs: Any,
) -> "SerializedData":
env = {DISABLE_COLOR_ENV_VAR: "true"}
attrs = {"_context_line_count": 0}
with env_context(**env), obj_attrs(self, attrs):
return "\n".join(self.diff_lines(serialized_data, snapshot_data))

def diff_lines(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
self,
serialized_data: "SerializedData",
snapshot_data: "SerializedData",
**kwargs: Any,
) -> Iterator[str]:
for line in self.__diff_lines(str(snapshot_data), str(serialized_data)):
yield reset(line)
Expand Down Expand Up @@ -407,6 +415,7 @@ def matches(
*,
serialized_data: "SerializableData",
snapshot_data: "SerializableData",
**kwargs: Any,
) -> bool:
"""
Compares serialized data and snapshot data and returns
Expand Down
1 change: 1 addition & 0 deletions src/syrupy/extensions/json/__init__.py
Expand Up @@ -145,6 +145,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs: Any,
) -> "SerializedData":
data = self._filter(
data=data,
Expand Down
14 changes: 11 additions & 3 deletions src/syrupy/extensions/single_file.py
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Optional,
Set,
Type,
Expand Down Expand Up @@ -49,6 +50,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs: Any,
) -> "SerializedData":
return self.get_supported_dataclass()(data)

Expand All @@ -74,12 +76,15 @@ def _get_file_basename(
return cls.get_snapshot_name(test_location=test_location, index=index)

@classmethod
def dirname(cls, *, test_location: "PyTestLocation") -> str:
def dirname(cls, *, test_location: "PyTestLocation", **kwargs: Any) -> str:
original_dirname = AbstractSyrupyExtension.dirname(test_location=test_location)
return str(Path(original_dirname).joinpath(test_location.basename))

def _read_snapshot_collection(
self, *, snapshot_location: str
self,
*,
snapshot_location: str,
**kwargs: Any,
) -> "SnapshotCollection":
file_ext_len = len(self._file_extension) + 1 if self._file_extension else 0
filename_wo_ext = snapshot_location[:-file_ext_len]
Expand Down Expand Up @@ -116,7 +121,10 @@ def get_write_encoding(cls) -> Optional[str]:

@classmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
cls,
*,
snapshot_collection: "SnapshotCollection",
**kwargs: Any,
) -> None:
filepath, data = (
snapshot_collection.location,
Expand Down