Skip to content

Commit

Permalink
Merge pull request #89 from ebotiab/add-kwargs-wrappers
Browse files Browse the repository at this point in the history
Add kwargs in model evaluators and analyzer wrappers predict method
  • Loading branch information
omri374 committed Nov 15, 2023
2 parents 804b608 + 36e86cb commit 40a83fd
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
2 changes: 1 addition & 1 deletion presidio_evaluator/models/base_model.py
Expand Up @@ -33,7 +33,7 @@ def __init__(
self.verbose = verbose

@abstractmethod
def predict(self, sample: InputSample) -> List[str]:
def predict(self, sample: InputSample, **kwargs) -> List[str]:
"""
Abstract. Returns the predicted tokens/spans from the evaluated model
:param sample: Sample to be evaluated
Expand Down
25 changes: 20 additions & 5 deletions presidio_evaluator/models/presidio_analyzer_wrapper.py
@@ -1,6 +1,6 @@
from typing import List, Optional, Dict

from presidio_analyzer import AnalyzerEngine
from presidio_analyzer import AnalyzerEngine, EntityRecognizer

from presidio_evaluator import InputSample, span_to_tag
from presidio_evaluator.models import BaseModel
Expand All @@ -16,6 +16,9 @@ def __init__(
score_threshold: float = 0.4,
language: str = "en",
entity_mapping: Optional[Dict[str, str]] = None,
ad_hoc_recognizers: Optional[List[EntityRecognizer]] = None,
context: Optional[List[str]] = None,
allow_list: Optional[List[str]] = None,
):
"""
Evaluation wrapper for the Presidio Analyzer
Expand All @@ -29,25 +32,37 @@ def __init__(
)
self.score_threshold = score_threshold
self.language = language
self.ad_hoc_recognizers = ad_hoc_recognizers
self.context = context
self.allow_list = allow_list

if not analyzer_engine:
analyzer_engine = AnalyzerEngine()
self._update_recognizers_based_on_entities_to_keep(analyzer_engine)
self.analyzer_engine = analyzer_engine

def predict(self, sample: InputSample) -> List[str]:
def predict(self, sample: InputSample, **kwargs) -> List[str]:
language = kwargs.get("language", self.language)
score_threshold = kwargs.get("score_threshold", self.score_threshold)
ad_hoc_recognizers = kwargs.get("ad_hoc_recognizers", self.ad_hoc_recognizers)
context = kwargs.get("context", self.context)
allow_list = kwargs.get("allow_list", self.allow_list)

results = self.analyzer_engine.analyze(
text=sample.full_text,
entities=self.entities,
language=self.language,
score_threshold=self.score_threshold,
language=language,
score_threshold=score_threshold,
ad_hoc_recognizers=ad_hoc_recognizers,
context=context,
allow_list=allow_list,
**kwargs,
)
starts = []
ends = []
scores = []
tags = []
#

for res in results:
starts.append(res.start)
ends.append(res.end)
Expand Down

0 comments on commit 40a83fd

Please sign in to comment.