Skip to content

Commit

Permalink
Add diversity parameter in max sum sim (#7)
Browse files Browse the repository at this point in the history
* Add diversity parameter in max sum sim
* Remove 3.7 testing due to timeout errors
  • Loading branch information
MaartenGr committed Oct 28, 2020
1 parent a67fba5 commit 8fd836c
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]
python-version: [3.6, 3.8]

steps:
- uses: actions/checkout@v2
Expand Down
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,12 @@ Then, we take all top_n combinations from the 2 x top_n words and extract the co
that are the least similar to each other by cosine similarity.

```python
>>> model.extract_keywords(doc, keyphrase_length=3, stop_words='english', use_maxsum=True)
['signal supervised learning',
>>> model.extract_keywords(doc, keyphrase_length=3, stop_words='english',
use_maxsum=True, nr_candidates=20, top_n=5)
['set training examples',
'generalize training data',
'requires learning algorithm',
'learning function maps',
'algorithm analyzes training',
'superivsed learning algorithm',
'learning machine learning']
```

Expand Down
9 changes: 7 additions & 2 deletions keybert/maxsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
def max_sum_similarity(doc_embedding: np.ndarray,
word_embeddings: np.ndarray,
words: List[str],
top_n: int) -> List[str]:
top_n: int,
nr_candidates: int) -> List[str]:
""" Calculate Max Sum Distance for extraction of keywords
We take the 2 x top_n most similar words/phrases to the document.
Expand All @@ -23,17 +24,21 @@ def max_sum_similarity(doc_embedding: np.ndarray,
word_embeddings: The embeddings of the selected candidate keywords/phrases
words: The selected candidate keywords/keyphrases
top_n: The number of keywords/keyhprases to return
nr_candidates: The number of candidates to consider
Returns:
List[str]: The selected keywords/keyphrases
"""
if nr_candidates < top_n:
raise Exception("Make sure that the number of candidates exceeds the number "
"of keywords to return.")

# Calculate distances and extract keywords
distances = cosine_similarity(doc_embedding, word_embeddings)
distances_words = cosine_similarity(word_embeddings, word_embeddings)

# Get 2*top_n words as candidates based on cosine similarity
words_idx = list(distances.argsort()[0][-top_n*2:])
words_idx = list(distances.argsort()[0][-nr_candidates:])
words_vals = [words[index] for index in words_idx]
candidates = distances_words[np.ix_(words_idx, words_idx)]

Expand Down
14 changes: 10 additions & 4 deletions keybert/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def extract_keywords(self,
min_df: int = 1,
use_maxsum: bool = False,
use_mmr: bool = False,
diversity: float = 0.5) -> Union[List[str], List[List[str]]]:
diversity: float = 0.5,
nr_candidates: int = 20) -> Union[List[str], List[List[str]]]:
""" Extract keywords/keyphrases
NOTE:
Expand Down Expand Up @@ -72,6 +73,8 @@ def extract_keywords(self,
selection of keywords/keyphrases
diversity: The diversity of the results between 0 and 1 if use_mmr
is set to True
nr_candidates: The number of candidates to consider if use_maxsum is
set to True
Returns:
keywords: the top n keywords for a document
Expand All @@ -85,7 +88,8 @@ def extract_keywords(self,
top_n,
use_maxsum,
use_mmr,
diversity)
diversity,
nr_candidates)
elif isinstance(docs, list):
warnings.warn("Although extracting keywords for multiple documents is faster "
"than iterating over single documents, it requires significant memory "
Expand All @@ -103,7 +107,8 @@ def _extract_keywords_single_doc(self,
top_n: int = 5,
use_maxsum: bool = False,
use_mmr: bool = False,
diversity: float = 0.5) -> List[str]:
diversity: float = 0.5,
nr_candidates: int = 20) -> List[str]:
""" Extract keywords/keyphrases for a single document
Arguments:
Expand All @@ -114,6 +119,7 @@ def _extract_keywords_single_doc(self,
use_mmr: Whether to use Max Sum Similarity
use_mmr: Whether to use MMR
diversity: The diversity of results between 0 and 1 if use_mmr is True
nr_candidates: The number of candidates to consider if use_maxsum is set to True
Returns:
keywords: The top n keywords for a document
Expand All @@ -133,7 +139,7 @@ def _extract_keywords_single_doc(self,
if use_mmr:
keywords = mmr(doc_embedding, word_embeddings, words, top_n, diversity)
elif use_maxsum:
keywords = max_sum_similarity(doc_embedding, word_embeddings, words, top_n)
keywords = max_sum_similarity(doc_embedding, word_embeddings, words, top_n, nr_candidates)
else:
distances = cosine_similarity(doc_embedding, word_embeddings)
keywords = [words[index] for index in distances.argsort()[0][-top_n:]][::-1]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
setuptools.setup(
name="keybert",
packages=["keybert"],
version="0.1.1",
version="0.1.2",
author="Maarten Grootendorst",
author_email="maartengrootendorst@gmail.com",
description="KeyBERT performs keyword extraction with state-of-the-art transformer models.",
Expand Down

0 comments on commit 8fd836c

Please sign in to comment.