Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Taqi Jaffri committed Mar 13, 2024
1 parent db5f98a commit 9592d58
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 12 deletions.
6 changes: 3 additions & 3 deletions docugami_dfm_benchmarks/utils/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sentence_transformers import SentenceTransformer, util
from torch.types import Number

from docugami_dfm_benchmarks.utils.text import get_tokens
from docugami_dfm_benchmarks.utils.text import get_tokens, normalize

embedding_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

Expand All @@ -16,8 +16,8 @@ def semantic_similarity(text1: str, text2: str) -> Number:


def compute_f1(text1: str, text2: str) -> float:
gold_toks = get_tokens(text1)
pred_toks = get_tokens(text2)
gold_toks = get_tokens(normalize(text1))
pred_toks = get_tokens(normalize(text2))
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
num_same = sum(common.values())
if len(gold_toks) == 0 or len(pred_toks) == 0:
Expand Down
15 changes: 15 additions & 0 deletions tests/utils/test_similarity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest

from docugami_dfm_benchmarks.utils.similarity import compute_f1


@pytest.mark.parametrize(
    "candidate,reference,expected_score",
    [
        # Identical strings must score a perfect F1.
        ("This is a test", "This is a test", 1.0),
        # Articles, casing and surrounding whitespace are normalized away,
        # so these still count as an exact match.
        ("One two a three", "one two three", 1.0),
        # Disjoint token sets share nothing and score zero.
        ("One two a three", " four five a six", 0.0),
    ],
)
def test_compute_f1(candidate: str, reference: str, expected_score: float) -> None:
    """compute_f1 is insensitive to articles, casing and extra whitespace."""
    assert compute_f1(candidate, reference) == expected_score
17 changes: 8 additions & 9 deletions tests/utils/test_text.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,40 @@
import pytest
from docugami_dfm_benchmarks.utils.text import normalize, get_tokens
from docugami_dfm_benchmarks.utils.text import get_tokens, normalize


def test_normalize_basic() -> None:
    """Test normalization on a simple string: lowercases and drops the article."""
    assert normalize("This is an example.") == "this is example"


def test_normalize_with_punctuation() -> None:
    """Test normalization removes punctuation."""
    assert normalize("Hello, world!") == "hello world"


def test_normalize_with_articles() -> None:
    """Test normalization removes the articles 'a', 'an', and 'the'."""
    assert (
        normalize("A quick brown fox jumps over the lazy dog.")
        == "quick brown fox jumps over lazy dog"
    )


def test_normalize_with_extra_whitespace() -> None:
    """Test normalization collapses leading/trailing/extra whitespace."""
    assert normalize("  This  is a test. ") == "this is test"


def test_get_tokens_empty() -> None:
    """Test get_tokens returns an empty list for empty input."""
    assert get_tokens("") == []


def test_get_tokens_basic() -> None:
    """Test get_tokens on a simple string (normalized, article dropped)."""
    assert get_tokens("This is a test.") == ["this", "is", "test"]


def test_get_tokens_complex():
def test_get_tokens_complex() -> None:
"""Test get_tokens with punctuation and extra whitespace."""
expected = [
"this",
Expand Down

0 comments on commit 9592d58

Please sign in to comment.