Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: scaffolding to support custom context in extensions #816

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 24 additions & 7 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class SnapshotAssertion:
exclude: Optional["PropertyFilter"] = None
matcher: Optional["PropertyMatcher"] = None

# context is reserved exclusively for custom extensions
context: Optional[Dict[str, Any]] = None

_exclude: Optional["PropertyFilter"] = field(
init=False,
default=None,
Expand Down Expand Up @@ -109,7 +112,8 @@ def __post_init__(self) -> None:
def __init_extension(
self, extension_class: Type["AbstractSyrupyExtension"]
) -> "AbstractSyrupyExtension":
return extension_class()
kwargs = {"context": self.context} if self.context else {}
return extension_class(**kwargs)

@property
def extension(self) -> "AbstractSyrupyExtension":
Expand Down Expand Up @@ -178,6 +182,7 @@ def with_defaults(
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
context: Optional[Dict[str, Any]] = None,
) -> "SnapshotAssertion":
"""
Create new snapshot assertion fixture with provided values. This preserves
Expand All @@ -191,6 +196,7 @@ def with_defaults(
test_location=self.test_location,
extension_class=extension_class or self.extension_class,
session=self.session,
context=context or self.context,
)

def use_extension(
Expand All @@ -207,7 +213,10 @@ def assert_match(self, data: "SerializableData") -> None:

def _serialize(self, data: "SerializableData") -> "SerializedData":
return self.extension.serialize(
data, exclude=self._exclude, include=self._include, matcher=self.__matcher
data,
exclude=self._exclude,
include=self._include,
matcher=self.__matcher,
)

def get_assert_diff(self) -> List[str]:
Expand Down Expand Up @@ -264,6 +273,7 @@ def __call__(
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
matcher: Optional["PropertyMatcher"] = None,
name: Optional["SnapshotIndex"] = None,
context: Optional[Dict[str, Any]] = None,
) -> "SnapshotAssertion":
"""
Modifies assertion instance options
Expand All @@ -272,14 +282,18 @@ def __call__(
self.__with_prop("_exclude", exclude)
if include:
self.__with_prop("_include", include)
if extension_class:
self.__with_prop("_extension", self.__init_extension(extension_class))
if matcher:
self.__with_prop("_matcher", matcher)
if name:
self.__with_prop("_custom_index", name)
if diff is not None:
self.__with_prop("_snapshot_diff", diff)
if context and context != self.context:
self.__with_prop("context", context)
# We need to force the extension to be re-initialized if the context changes
extension_class = extension_class or self.extension_class
if extension_class:
self.__with_prop("_extension", self.__init_extension(extension_class))
return self

def __repr__(self) -> str:
Expand All @@ -290,10 +304,12 @@ def __eq__(self, other: "SerializableData") -> bool:

def _assert(self, data: "SerializableData") -> bool:
snapshot_location = self.extension.get_location(
test_location=self.test_location, index=self.index
test_location=self.test_location,
index=self.index,
)
snapshot_name = self.extension.get_snapshot_name(
test_location=self.test_location, index=self.index
test_location=self.test_location,
index=self.index,
)
snapshot_data: Optional["SerializedData"] = None
serialized_data: Optional["SerializedData"] = None
Expand All @@ -316,7 +332,8 @@ def _assert(self, data: "SerializableData") -> bool:
not tainted
and snapshot_data is not None
and self.extension.matches(
serialized_data=serialized_data, snapshot_data=snapshot_data
serialized_data=serialized_data,
snapshot_data=snapshot_data,
)
)
assertion_success = matches
Expand Down
19 changes: 15 additions & 4 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ class SnapshotCollectionStorage(ABC):

@classmethod
def get_snapshot_name(
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" = 0
cls,
*,
test_location: "PyTestLocation",
index: "SnapshotIndex" = 0,
) -> str:
"""Get the snapshot name for the assertion index in a test location"""
index_suffix = ""
Expand Down Expand Up @@ -225,7 +228,11 @@ def _read_snapshot_collection(

@abstractmethod
def _read_snapshot_data_from_location(
self, *, snapshot_location: str, snapshot_name: str, session_id: str
self,
*,
snapshot_location: str,
snapshot_name: str,
session_id: str,
) -> Optional["SerializedData"]:
"""
Get only the snapshot data from location for assertion
Expand Down Expand Up @@ -259,15 +266,19 @@ class SnapshotReporter(ABC):
_context_line_count = 1

def diff_snapshots(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
self,
serialized_data: "SerializedData",
snapshot_data: "SerializedData",
) -> "SerializedData":
env = {DISABLE_COLOR_ENV_VAR: "true"}
attrs = {"_context_line_count": 0}
with env_context(**env), obj_attrs(self, attrs):
return "\n".join(self.diff_lines(serialized_data, snapshot_data))

def diff_lines(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
self,
serialized_data: "SerializedData",
snapshot_data: "SerializedData",
) -> Iterator[str]:
for line in self.__diff_lines(str(snapshot_data), str(serialized_data)):
yield reset(line)
Expand Down
1 change: 1 addition & 0 deletions src/syrupy/extensions/json/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs: Any,
) -> "SerializedData":
data = self._filter(
data=data,
Expand Down
19 changes: 15 additions & 4 deletions src/syrupy/extensions/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ def serialize(

@classmethod
def get_snapshot_name(
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" = 0
cls,
*,
test_location: "PyTestLocation",
index: "SnapshotIndex" = 0,
) -> str:
return cls.__clean_filename(
AbstractSyrupyExtension.get_snapshot_name(
Expand All @@ -79,7 +82,9 @@ def dirname(cls, *, test_location: "PyTestLocation") -> str:
return str(Path(original_dirname).joinpath(test_location.basename))

def _read_snapshot_collection(
self, *, snapshot_location: str
self,
*,
snapshot_location: str,
) -> "SnapshotCollection":
file_ext_len = len(self._file_extension) + 1 if self._file_extension else 0
filename_wo_ext = snapshot_location[:-file_ext_len]
Expand All @@ -90,7 +95,11 @@ def _read_snapshot_collection(
return snapshot_collection

def _read_snapshot_data_from_location(
self, *, snapshot_location: str, snapshot_name: str, session_id: str
self,
*,
snapshot_location: str,
snapshot_name: str,
session_id: str,
) -> Optional["SerializableData"]:
try:
with open(
Expand All @@ -116,7 +125,9 @@ def get_write_encoding(cls) -> Optional[str]:

@classmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
cls,
*,
snapshot_collection: "SnapshotCollection",
) -> None:
filepath, data = (
snapshot_collection.location,
Expand Down
31 changes: 27 additions & 4 deletions src/syrupy/session.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
from collections import defaultdict
from dataclasses import (
dataclass,
Expand Down Expand Up @@ -54,7 +55,7 @@ class SnapshotSession:
)

_queued_snapshot_writes: Dict[
Tuple[Type["AbstractSyrupyExtension"], str],
Tuple[Type["AbstractSyrupyExtension"], Optional[bytes], str],
List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]],
] = field(default_factory=dict)

Expand All @@ -68,19 +69,41 @@ def queue_snapshot_write(
snapshot_location = extension.get_location(
test_location=test_location, index=index
)
key = (extension.__class__, snapshot_location)

extension_context = getattr(extension, "context", None)

try:
extension_kwargs_bytes = (
pickle.dumps(extension_context) if extension_context else None
)
except pickle.PicklingError:
print("Extension context must be serializable.")
raise

key = (extension.__class__, extension_kwargs_bytes, snapshot_location)
queue = self._queued_snapshot_writes.get(key, [])
queue.append((data, test_location, index))
self._queued_snapshot_writes[key] = queue

def flush_snapshot_write_queue(self) -> None:
for (
extension_class,
extension_kwargs_bytes,
snapshot_location,
), queued_write in self._queued_snapshot_writes.items():
if queued_write:
extension_class.write_snapshot(
snapshot_location=snapshot_location, snapshots=queued_write
# It's possible to instantiate an extension with context. We need to
# ensure we never lose context between instantiations (since we may
# instantiate multiple times in a test session).
extension_kwargs = (
{"context": pickle.loads(extension_kwargs_bytes)}
if extension_kwargs_bytes
else {}
)
extension = extension_class(**extension_kwargs)
extension.write_snapshot(
snapshot_location=snapshot_location,
snapshots=queued_write,
)
self._queued_snapshot_writes = {}

Expand Down
6 changes: 4 additions & 2 deletions tests/examples/test_custom_snapshot_name.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Example: Custom Snapshot Name
"""
from typing import Any

import pytest

from syrupy.extensions.amber import AmberSnapshotExtension
Expand All @@ -11,10 +13,10 @@
class CanadianNameExtension(AmberSnapshotExtension):
@classmethod
def get_snapshot_name(
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex"
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex", **kwargs: Any
) -> str:
original_name = AmberSnapshotExtension.get_snapshot_name(
test_location=test_location, index=index
test_location=test_location, index=index, **kwargs
)
return f"{original_name}🇨🇦"

Expand Down