github.com/MaartenGr/KeyBERT.git
author    MaartenGr <maarten_grootendorst@hotmail.com>  2020-12-02 12:58:09 +0300
committer MaartenGr <maarten_grootendorst@hotmail.com>  2020-12-02 12:58:09 +0300
commit    8579108919ff044ca5849c388275cc3addd628f8 (patch)
tree      93f28d8b9569c2580b497844ebf09dbe34b2b135
parent    43fed7547b8ff121ab738205a13da20be761c444 (diff)

Add custom countvectorizer (branch: feature-ngram)
-rw-r--r--  keybert/__init__.py   1
-rw-r--r--  keybert/model.py     50
-rw-r--r--  setup.py              2
-rw-r--r--  tests/test_model.py  33
4 files changed, 54 insertions(+), 32 deletions(-)
diff --git a/keybert/__init__.py b/keybert/__init__.py
index 1d41a2e..b6e17af 100644
--- a/keybert/__init__.py
+++ b/keybert/__init__.py
@@ -1 +1,2 @@
from keybert.model import KeyBERT
+__version__ = "0.1.3"
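
With this hunk applied, the package version becomes inspectable at runtime; a minimal check, assuming an install from this commit:

    import keybert
    print(keybert.__version__)  # "0.1.3" as of this commit
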
diff --git a/keybert/model.py b/keybert/model.py
index e890de3..5a16189 100644
--- a/keybert/model.py
+++ b/keybert/model.py
@@ -3,7 +3,7 @@ from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import CountVectorizer
from tqdm import tqdm
-from typing import List, Union
+from typing import List, Union, Tuple
import warnings
from .mmr import mmr
from .maxsum import max_sum_similarity
@@ -35,14 +35,15 @@ class KeyBERT:
def extract_keywords(self,
docs: Union[str, List[str]],
- keyphrase_length: int = 1,
+ keyphrase_ngram_range: Tuple[int, int] = (1, 1),
stop_words: Union[str, List[str]] = 'english',
top_n: int = 5,
min_df: int = 1,
use_maxsum: bool = False,
use_mmr: bool = False,
diversity: float = 0.5,
- nr_candidates: int = 20) -> Union[List[str], List[List[str]]]:
+ nr_candidates: int = 20,
+ vectorizer: CountVectorizer = None) -> Union[List[str], List[List[str]]]:
""" Extract keywords/keyphrases
NOTE:
@@ -62,7 +63,7 @@ class KeyBERT:
Arguments:
docs: The document(s) for which to extract keywords/keyphrases
- keyphrase_length: Length, in words, of the extracted keywords/keyphrases
+ keyphrase_ngram_range: Lower and upper bound, in words, of the extracted keywords/keyphrases
stop_words: Stopwords to remove from the document
top_n: Return the top n keywords/keyphrases
min_df: Minimum document frequency of a word across all documents
@@ -75,6 +76,7 @@ class KeyBERT:
is set to True
nr_candidates: The number of candidates to consider if use_maxsum is
set to True
+ vectorizer: Pass in your own CountVectorizer from scikit-learn
Returns:
keywords: the top n keywords for a document
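
To illustrate the renamed parameter, a minimal usage sketch; the document text is an assumption for illustration:

    from keybert import KeyBERT

    doc = "Supervised learning is the machine learning task of learning a function."
    model = KeyBERT()  # assumes the constructor's default embedding model
    # a (1, 2) range extracts keyphrases of one or two words, replacing the
    # old fixed-length keyphrase_length behaviour
    keywords = model.extract_keywords(doc, keyphrase_ngram_range=(1, 2), top_n=5)
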
@@ -83,43 +85,47 @@ class KeyBERT:
if isinstance(docs, str):
return self._extract_keywords_single_doc(docs,
- keyphrase_length,
+ keyphrase_ngram_range,
stop_words,
top_n,
use_maxsum,
use_mmr,
diversity,
- nr_candidates)
+ nr_candidates,
+ vectorizer)
elif isinstance(docs, list):
warnings.warn("Although extracting keywords for multiple documents is faster "
- "than iterating over single documents, it requires significant memory "
+ "than iterating over single documents, it requires significantly more memory "
"to hold all word embeddings. Use this at your own discretion!")
return self._extract_keywords_multiple_docs(docs,
- keyphrase_length,
+ keyphrase_ngram_range,
stop_words,
top_n,
- min_df=min_df)
+ min_df,
+ vectorizer)
def _extract_keywords_single_doc(self,
doc: str,
- keyphrase_length: int = 1,
+ keyphrase_ngram_range: Tuple[int, int] = (1, 1),
stop_words: Union[str, List[str]] = 'english',
top_n: int = 5,
use_maxsum: bool = False,
use_mmr: bool = False,
diversity: float = 0.5,
- nr_candidates: int = 20) -> List[str]:
+ nr_candidates: int = 20,
+ vectorizer: CountVectorizer = None) -> List[str]:
""" Extract keywords/keyphrases for a single document
Arguments:
doc: The document for which to extract keywords/keyphrases
- keyphrase_length: Length, in words, of the extracted keywords/keyphrases
+ keyphrase_ngram_range: Lower and upper bound, in words, of the extracted keywords/keyphrases
stop_words: Stopwords to remove from the document
top_n: Return the top n keywords/keyphrases
use_maxsum: Whether to use Max Sum Similarity
use_mmr: Whether to use MMR
diversity: The diversity of results between 0 and 1 if use_mmr is True
nr_candidates: The number of candidates to consider if use_maxsum is set to True
+ vectorizer: Pass in your own CountVectorizer from scikit-learn
Returns:
keywords: The top n keywords for a document
@@ -127,8 +133,10 @@ class KeyBERT:
"""
try:
# Extract Words
- n_gram_range = (keyphrase_length, keyphrase_length)
- count = CountVectorizer(ngram_range=n_gram_range, stop_words=stop_words).fit([doc])
+ if vectorizer:
+ count = vectorizer.fit([doc])
+ else:
+ count = CountVectorizer(ngram_range=keyphrase_ngram_range, stop_words=stop_words).fit([doc])
words = count.get_feature_names()
# Extract Embeddings
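
For context, the candidate-selection step above amounts to the following; get_feature_names() matches the scikit-learn API contemporary with this commit (later releases renamed it to get_feature_names_out()):

    from sklearn.feature_extraction.text import CountVectorizer

    doc = "Keyword extraction finds the most relevant words in a document."
    # same call the fallback branch makes when no custom vectorizer is passed
    count = CountVectorizer(ngram_range=(1, 1), stop_words="english").fit([doc])
    candidates = count.get_feature_names()  # candidate words to embed and rank
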
@@ -150,28 +158,32 @@ class KeyBERT:
def _extract_keywords_multiple_docs(self,
docs: List[str],
- keyphrase_length: int = 1,
+ keyphrase_ngram_range: Tuple[int, int] = (1, 1),
stop_words: str = 'english',
top_n: int = 5,
- min_df: int = 1):
+ min_df: int = 1,
+ vectorizer: CountVectorizer = None):
""" Extract keywords/keyphrases for a multiple documents
This currently does not use MMR as
Arguments:
docs: The documents for which to extract keywords/keyphrases
- keyphrase_length: Length, in words, of the extracted keywords/keyphrases
+ keyphrase_ngram_range: Lower and upper bound, in words, of the extracted keywords/keyphrases
stop_words: Stopwords to remove from the document
top_n: Return the top n keywords/keyphrases
min_df: The minimum document frequency of a word across all documents
+ vectorizer: Pass in your own CountVectorizer from scikit-learn
Returns:
keywords: The top n keywords for a document
"""
# Extract words
- n_gram_range = (keyphrase_length, keyphrase_length)
- count = CountVectorizer(ngram_range=n_gram_range, stop_words=stop_words, min_df=min_df).fit(docs)
+ if vectorizer:
+ count = vectorizer.fit(docs)
+ else:
+ count = CountVectorizer(ngram_range=keyphrase_ngram_range, stop_words=stop_words, min_df=min_df).fit(docs)
words = count.get_feature_names()
df = count.transform(docs)
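
The net effect of the model.py changes: a user-supplied CountVectorizer takes precedence over keyphrase_ngram_range, stop_words, and min_df, which are bypassed when it is given. A minimal sketch of the new option (document text is an assumption):

    from sklearn.feature_extraction.text import CountVectorizer
    from keybert import KeyBERT

    doc = "Keyword extraction finds the most relevant words in a document."
    # the vectorizer's own ngram_range and stop_words are used; the
    # keyphrase_ngram_range and stop_words arguments are ignored entirely
    vectorizer = CountVectorizer(ngram_range=(1, 3), stop_words="english")
    model = KeyBERT()  # assumes the constructor's default embedding model
    keywords = model.extract_keywords(doc, vectorizer=vectorizer)
    # a list of documents goes through the multiple-documents path, which is
    # faster but holds all word embeddings in memory at once
    keywords_per_doc = model.extract_keywords([doc, doc], vectorizer=vectorizer)
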
diff --git a/setup.py b/setup.py
index 060ba0c..7a4b5ff 100644
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@ with open("README.md", "r") as fh:
setuptools.setup(
name="keybert",
packages=["keybert"],
- version="0.1.2",
+ version="0.1.3",
author="Maarten Grootendorst",
author_email="maartengrootendorst@gmail.com",
description="KeyBERT performs keyword extraction with state-of-the-art transformer models.",
diff --git a/tests/test_model.py b/tests/test_model.py
index 2590b69..c03fd48 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1,47 +1,56 @@
import pytest
from .utils import get_test_data
+from sklearn.feature_extraction.text import CountVectorizer
doc_one, doc_two = get_test_data()
-@pytest.mark.parametrize("keyphrase_length", [i+1 for i in range(5)])
-def test_single_doc(keyphrase_length, base_keybert):
+@pytest.mark.parametrize("keyphrase_length", [(1, i+1) for i in range(5)])
+@pytest.mark.parametrize("vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")])
+def test_single_doc(keyphrase_length, vectorizer, base_keybert):
""" Test whether the keywords are correctly extracted """
top_n = 5
- keywords = base_keybert.extract_keywords(doc_one, keyphrase_length=keyphrase_length, min_df=1, top_n=top_n)
+
+ keywords = base_keybert.extract_keywords(doc_one,
+ keyphrase_ngram_range=keyphrase_length,
+ min_df=1,
+ top_n=top_n,
+ vectorizer=vectorizer)
assert isinstance(keywords, list)
assert isinstance(keywords[0], str)
assert len(keywords) == top_n
for keyword in keywords:
- assert len(keyword.split(" ")) == keyphrase_length
+ assert len(keyword.split(" ")) <= keyphrase_length[1]
-@pytest.mark.parametrize("keyphrase_length, mmr, maxsum", [(i+1, truth, not truth)
+@pytest.mark.parametrize("keyphrase_length, mmr, maxsum", [((1, i+1), truth, not truth)
for i in range(4)
for truth in [True, False]])
-def test_extract_keywords_single_doc(keyphrase_length, mmr, maxsum, base_keybert):
+@pytest.mark.parametrize("vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")])
+def test_extract_keywords_single_doc(keyphrase_length, mmr, maxsum, vectorizer, base_keybert):
""" Test extraction of protected single document method """
top_n = 5
keywords = base_keybert._extract_keywords_single_doc(doc_one,
top_n=top_n,
- keyphrase_length=keyphrase_length,
+ keyphrase_ngram_range=keyphrase_length,
use_mmr=mmr,
use_maxsum=maxsum,
- diversity=0.5)
+ diversity=0.5,
+ vectorizer=vectorizer)
assert isinstance(keywords, list)
assert isinstance(keywords[0], str)
assert len(keywords) == top_n
for keyword in keywords:
- assert len(keyword.split(" ")) == keyphrase_length
+ assert len(keyword.split(" ")) <= keyphrase_length[1]
-@pytest.mark.parametrize("keyphrase_length", [i+1 for i in range(5)])
+@pytest.mark.parametrize("keyphrase_length", [(1, i+1) for i in range(5)])
def test_extract_keywords_multiple_docs(keyphrase_length, base_keybert):
""" Test extractino of protected multiple document method"""
top_n = 5
keywords_list = base_keybert._extract_keywords_multiple_docs([doc_one, doc_two],
top_n=top_n,
- keyphrase_length=keyphrase_length)
+ keyphrase_ngram_range=keyphrase_length)
assert isinstance(keywords_list, list)
assert isinstance(keywords_list[0], list)
assert len(keywords_list) == 2
@@ -50,7 +59,7 @@ def test_extract_keywords_multiple_docs(keyphrase_length, base_keybert):
assert len(keywords) == top_n
for keyword in keywords:
- assert len(keyword.split(" ")) == keyphrase_length
+ assert len(keyword.split(" ")) <= keyphrase_length[1]
def test_error(base_keybert):