
github.com/MaartenGr/KeyBERT.git
author    Maarten Grootendorst <maarten_grootendorst@hotmail.com>    2021-06-30 12:47:45 +0300
committer GitHub <noreply@github.com>    2021-06-30 12:47:45 +0300
commit    25dab3a763e5cc0450346dcf061754f71f184ae3 (patch)
tree      6e734d3d57a732a3bc213cc45dd16707b71bf04a
parent    eb6d0865c958a474f3554518539c7a37dbd9856b (diff)

v0.4 (#43) (tag v0.4.0)

* Use paraphrase-MiniLM-L6-v2 as the default embedding model
* Highlight a document's keywords
* Added FAQ
-rw-r--r--  README.md                                             |  21
-rw-r--r--  docs/changelog.md                                     |  14
-rw-r--r--  docs/faq.md                                           |  20
-rw-r--r--  docs/guides/embeddings.md                             |   8
-rw-r--r--  docs/guides/quickstart.md                             |  12
-rw-r--r--  docs/index.md                                         |  35
-rw-r--r--  images/highlight.png                                  | bin 0 -> 21583 bytes
-rw-r--r--  keybert/__init__.py                                   |   4
-rw-r--r--  keybert/_highlight.py                                 |  96
-rw-r--r--  keybert/_maxsum.py (renamed from keybert/maxsum.py)   |   0
-rw-r--r--  keybert/_mmr.py (renamed from keybert/mmr.py)         |   0
-rw-r--r--  keybert/_model.py (renamed from keybert/model.py)     |  40
-rw-r--r--  keybert/backend/_sentencetransformers.py              |   8
-rw-r--r--  keybert/backend/_utils.py                             |   7
-rw-r--r--  mkdocs.yml                                            |   1
-rw-r--r--  setup.py                                              |   4
-rw-r--r--  tests/conftest.py                                     |   8
-rw-r--r--  tests/test_model.py                                   |  47
18 files changed, 242 insertions(+), 83 deletions(-)
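Before the full diff, a minimal sketch of the usage this commit introduces (the new default `paraphrase-MiniLM-L6-v2` embedder plus the new `highlight` flag); the example document below is a placeholder, not taken from the repository:

```python
from keybert import KeyBERT

doc = "Supervised learning is the machine learning task of learning a function that maps an input to an output."

# v0.4.0 default: paraphrase-MiniLM-L6-v2, no model argument needed
kw_model = KeyBERT()

# Extract keywords and print the document with its keywords highlighted (single documents only)
keywords = kw_model.extract_keywords(doc, highlight=True)
print(keywords)
```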
diff --git a/README.md b/README.md
index 7659687..7e81848 100644
--- a/README.md
+++ b/README.md
@@ -90,8 +90,8 @@ from keybert import KeyBERT
doc = """
Supervised learning is the machine learning task of learning a function that
- maps an input to an output based on example input-output pairs.[1] It infers a
- function from labeled training data consisting of a set of training examples.[2]
+ maps an input to an output based on example input-output pairs. It infers a
+ function from labeled training data consisting of a set of training examples.
In supervised learning, each example is a pair consisting of an input object
(typically a vector) and a desired output value (also called the supervisory signal).
A supervised learning algorithm analyzes the training data and produces an inferred function,
@@ -100,7 +100,7 @@ doc = """
the learning algorithm to generalize from the training data to unseen situations in a
'reasonable' way (see inductive bias).
"""
-kw_model = KeyBERT('distilbert-base-nli-mean-tokens')
+kw_model = KeyBERT()
keywords = kw_model.extract_keywords(doc)
```
@@ -127,10 +127,17 @@ of words you would like in the resulting keyphrases:
('learning function', 0.5850)]
```
+We can highlight the keywords in the document by simply setting `highlight=True`:
+```python
+keywords = kw_model.extract_keywords(doc, highlight=True)
+```
+<img src="images/highlight.png" width="75%" height="75%" />
+
+
**NOTE**: For a full overview of all possible transformer models see [sentence-transformer](https://www.sbert.net/docs/pretrained_models.html).
-I would advise either `'distilbert-base-nli-mean-tokens'` or `'xlm-r-distilroberta-base-paraphrase-v1'` as they
-have shown great performance in semantic similarity and paraphrase identification respectively.
+I would advise `"paraphrase-MiniLM-L6-v2"` for English documents and `"paraphrase-multilingual-MiniLM-L12-v2"`
+for multilingual documents or documents in any other language.
<a name="maxsum"/></a>
### 2.3. Max Sum Similarity
@@ -198,7 +205,7 @@ and pass it through KeyBERT with `model`:
```python
from keybert import KeyBERT
-kw_model = KeyBERT(model='distilbert-base-nli-mean-tokens')
+kw_model = KeyBERT(model='paraphrase-MiniLM-L6-v2')
```
Or select a SentenceTransformer model with your own parameters:
@@ -207,7 +214,7 @@ Or select a SentenceTransformer model with your own parameters:
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer
-sentence_model = SentenceTransformer("distilbert-base-nli-mean-tokens", device="cpu")
+sentence_model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
kw_model = KeyBERT(model=sentence_model)
```
diff --git a/docs/changelog.md b/docs/changelog.md
index 52e1b05..fc56e81 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -1,3 +1,17 @@
+## **Version 0.4.0**
+*Release date: 23 June, 2021*
+
+**Highlights**:
+
+* Highlight a document's keywords with:
+ * ```keywords = kw_model.extract_keywords(doc, highlight=True)```
+* Use `paraphrase-MiniLM-L6-v2` as the default embedder which gives great results!
+
+**Miscellaneous**:
+
+* Update Flair dependencies
+* Added FAQ
+
## **Version 0.3.0**
*Release date: 10 May, 2021*
diff --git a/docs/faq.md b/docs/faq.md
new file mode 100644
index 0000000..a883b58
--- /dev/null
+++ b/docs/faq.md
@@ -0,0 +1,20 @@
+## **Which embedding model works best for which language?**
+Unfortunately, there is not a definitive list of the best models for each language; this highly depends
+on your data, the model, and your specific use-case. However, the default model in KeyBERT
+(`"paraphrase-MiniLM-L6-v2"`) works great for **English** documents. In contrast, for **multi-lingual**
+documents or any other language, `"paraphrase-multilingual-MiniLM-L12-v2"` has shown great performance.
+
+If you want a model that provides higher quality but takes more compute time, then I would advise using `paraphrase-mpnet-base-v2` or `paraphrase-multilingual-mpnet-base-v2` instead.
+
+
+## **Should I preprocess the data?**
+No. By using document embeddings there is typically no need to preprocess the data as all parts of a document
+are important in understanding the general topic of the document. Although this holds true in 99% of cases, if you
+have data that contains a lot of noise, for example, HTML-tags, then it would be best to remove them. HTML-tags
+typically do not contribute to the meaning of a document and should therefore be removed. However, if you apply
+topic modeling to HTML-code to extract topics of code, then it becomes important.
+
+
+## **Can I use the GPU to speed up the model?**
+Yes! Since KeyBERT uses embeddings as its backend, a GPU is actually preferred when using this package.
+Although it is possible to use it without a dedicated GPU, the inference speed will be significantly slower.
\ No newline at end of file
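As a concrete sketch of the GPU answer above (assuming a CUDA device is available; otherwise this mirrors the backend usage shown elsewhere in this diff), the embedding model can be placed on the GPU explicitly and then passed to KeyBERT:

```python
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer

# Load the default v0.4.0 model on the GPU; use device="cpu" if no CUDA device is available
sentence_model = SentenceTransformer("paraphrase-MiniLM-L6-v2", device="cuda")
kw_model = KeyBERT(model=sentence_model)
```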
diff --git a/docs/guides/embeddings.md b/docs/guides/embeddings.md
index 3faedad..febe479 100644
--- a/docs/guides/embeddings.md
+++ b/docs/guides/embeddings.md
@@ -8,7 +8,7 @@ and pass it through KeyBERT with `model`:
```python
from keybert import KeyBERT
-kw_model = KeyBERT(model="xlm-r-bert-base-nli-stsb-mean-tokens")
+kw_model = KeyBERT(model="paraphrase-MiniLM-L6-v2")
```
Or select a SentenceTransformer model with your own parameters:
@@ -16,7 +16,7 @@ Or select a SentenceTransformer model with your own parameters:
```python
from sentence_transformers import SentenceTransformer
-sentence_model = SentenceTransformer("distilbert-base-nli-mean-tokens", device="cuda")
+sentence_model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
kw_model = KeyBERT(model=sentence_model)
```
@@ -60,7 +60,7 @@ import spacy
nlp = spacy.load("en_core_web_md", exclude=['tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer'])
-kw_model = KeyBERT(model=document_glove_embeddings)nlp
+kw_model = KeyBERT(model=nlp)
```
Using spacy-transformer models:
@@ -129,7 +129,7 @@ class CustomEmbedder(BaseEmbedder):
return embeddings
# Create custom backend
-distilbert = SentenceTransformer("distilbert-base-nli-stsb-mean-tokens")
+distilbert = SentenceTransformer("paraphrase-MiniLM-L6-v2")
custom_embedder = CustomEmbedder(embedding_model=distilbert)
# Pass custom backend to keybert
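For context on the custom-backend hunk just above, here is a self-contained sketch of the surrounding `CustomEmbedder` example from embeddings.md, reconstructed from the `BaseEmbedder.embed` signature shown later in this diff rather than copied verbatim (the constructor details and the `keybert.backend.BaseEmbedder` import path are assumptions):

```python
from keybert import KeyBERT
from keybert.backend import BaseEmbedder
from sentence_transformers import SentenceTransformer


class CustomEmbedder(BaseEmbedder):
    def __init__(self, embedding_model):
        super().__init__()
        self.embedding_model = embedding_model

    def embed(self, documents, verbose=False):
        # Return an array of shape (n_documents, embedding_dim)
        embeddings = self.embedding_model.encode(documents, show_progress_bar=verbose)
        return embeddings


# Create the custom backend and pass it to KeyBERT
distilbert = SentenceTransformer("paraphrase-MiniLM-L6-v2")
custom_embedder = CustomEmbedder(embedding_model=distilbert)
kw_model = KeyBERT(model=custom_embedder)
```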
diff --git a/docs/guides/quickstart.md b/docs/guides/quickstart.md
index d5bf16b..ddf130a 100644
--- a/docs/guides/quickstart.md
+++ b/docs/guides/quickstart.md
@@ -38,7 +38,7 @@ doc = """
the learning algorithm to generalize from the training data to unseen situations in a
'reasonable' way (see inductive bias).
"""
-kw_model = KeyBERT('distilbert-base-nli-mean-tokens')
+kw_model = KeyBERT()
keywords = kw_model.extract_keywords(doc)
```
@@ -65,9 +65,15 @@ of words you would like in the resulting keyphrases:
('learning function', 0.5850)]
```
+We can highlight the keywords in the document by simply setting `highlight=True`:
+
+```python
+keywords = kw_model.extract_keywords(doc, highlight=True)
+```
+
**NOTE**: For a full overview of all possible transformer models see [sentence-transformer](https://www.sbert.net/docs/pretrained_models.html).
-I would advise either `'distilbert-base-nli-mean-tokens'` or `'xlm-r-distilroberta-base-paraphrase-v1'` as they
-have shown great performance in semantic similarity and paraphrase identification respectively.
+I would advise `"paraphrase-MiniLM-L6-v2"` for English documents and `"paraphrase-multilingual-MiniLM-L12-v2"`
+for multilingual documents or documents in any other language.
### Max Sum Similarity
diff --git a/docs/index.md b/docs/index.md
index a3f5e64..1b3b053 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -7,7 +7,7 @@ create keywords and keyphrases that are most similar to a document.
## About the Project
-Although that are already many methods available for keyword generation
+Although there are already many methods available for keyword generation
(e.g.,
[Rake](https://github.com/aneesha/RAKE),
[YAKE!](https://github.com/LIAAD/yake), TF-IDF, etc.)
@@ -30,11 +30,6 @@ papers and solutions out there that use BERT-embeddings
), I could not find a BERT-based solution that did not have to be trained from scratch and
could be used for beginners (**correct me if I'm wrong!**).
Thus, the goal was a `pip install keybert` and at most 3 lines of code in usage.
-
-**NOTE**: If you use MMR to select the candidates instead of simple cosine similarity,
-this repo is essentially a simplified implementation of
-[EmbedRank](https://github.com/swisscom/ai-research-keyphrase-extraction)
-with BERT-embeddings.
## Installation
Installation can be done using [pypi](https://pypi.org/project/keybert/):
@@ -43,22 +38,33 @@ Installation can be done using [pypi](https://pypi.org/project/keybert/):
pip install keybert
```
-To use Flair embeddings, install KeyBERT as follows:
+You may want to install additional dependencies depending on the transformer and language backends that you will be using. The possible installations are:
```
pip install keybert[flair]
+pip install keybert[gensim]
+pip install keybert[spacy]
+pip install keybert[use]
```
+To install all backends:
+
+```
+pip install keybert[all]
+```
+
+
## Usage
+
The most minimal example can be seen below for the extraction of keywords:
```python
from keybert import KeyBERT
doc = """
Supervised learning is the machine learning task of learning a function that
- maps an input to an output based on example input-output pairs.[1] It infers a
- function from labeled training data consisting of a set of training examples.[2]
+ maps an input to an output based on example input-output pairs. It infers a
+ function from labeled training data consisting of a set of training examples.
In supervised learning, each example is a pair consisting of an input object
(typically a vector) and a desired output value (also called the supervisory signal).
A supervised learning algorithm analyzes the training data and produces an inferred function,
@@ -67,13 +73,14 @@ doc = """
the learning algorithm to generalize from the training data to unseen situations in a
'reasonable' way (see inductive bias).
"""
-model = KeyBERT('distilbert-base-nli-mean-tokens')
+kw_model = KeyBERT()
+keywords = kw_model.extract_keywords(doc)
```
-You can set `keyphrase_length` to set the length of the resulting keyphras:
+You can set `keyphrase_ngram_range` to set the length of the resulting keywords/keyphrases:
```python
->>> model.extract_keywords(doc, keyphrase_ngram_range=(1, 1))
+>>> kw_model.extract_keywords(doc, keyphrase_ngram_range=(1, 1), stop_words=None)
[('learning', 0.4604),
('algorithm', 0.4556),
('training', 0.4487),
@@ -85,10 +92,10 @@ To extract keyphrases, simply set `keyphrase_ngram_range` to (1, 2) or higher de
of words you would like in the resulting keyphrases:
```python
->>> model.extract_keywords(doc, keyphrase_ngram_range=(1, 2))
+>>> kw_model.extract_keywords(doc, keyphrase_ngram_range=(1, 2), stop_words=None)
[('learning algorithm', 0.6978),
('machine learning', 0.6305),
('supervised learning', 0.5985),
('algorithm analyzes', 0.5860),
('learning function', 0.5850)]
-```
\ No newline at end of file
+```
diff --git a/images/highlight.png b/images/highlight.png
new file mode 100644
index 0000000..e6f1f85
--- /dev/null
+++ b/images/highlight.png
Binary files differ
diff --git a/keybert/__init__.py b/keybert/__init__.py
index 1f3b6e2..6d8bc87 100644
--- a/keybert/__init__.py
+++ b/keybert/__init__.py
@@ -1,3 +1,3 @@
-from keybert.model import KeyBERT
+from keybert._model import KeyBERT
-__version__ = "0.3.0"
+__version__ = "0.4.0"
diff --git a/keybert/_highlight.py b/keybert/_highlight.py
new file mode 100644
index 0000000..7ad98cf
--- /dev/null
+++ b/keybert/_highlight.py
@@ -0,0 +1,96 @@
+import re
+from rich.console import Console
+from rich.highlighter import RegexHighlighter
+from typing import Tuple, List
+
+
+class NullHighlighter(RegexHighlighter):
+ """Apply style to anything that looks like an email."""
+
+ base_style = ""
+ highlights = [r""]
+
+
+def highlight_document(doc: str,
+ keywords: List[Tuple[str, float]]):
+ """ Highlight keywords in a document
+
+ Arguments:
+ doc: The document for which to extract keywords/keyphrases
+ keywords: the top n keywords for a document with their respective distances
+ to the input document
+
+ Returns:
+ highlighted_text: The document with additional tags to highlight keywords
+ according to the rich package
+ """
+ keywords_only = [keyword for keyword, _ in keywords]
+ max_len = max([len(token.split(" ")) for token in keywords_only])
+
+ if max_len == 1:
+ highlighted_text = _highlight_one_gram(doc, keywords_only)
+ else:
+ highlighted_text = _highlight_n_gram(doc, keywords_only)
+
+ console = Console(highlighter=NullHighlighter())
+ console.print(highlighted_text)
+
+
+def _highlight_one_gram(doc: str,
+ keywords: List[str]) -> str:
+ """ Highlight 1-gram keywords in a document
+
+ Arguments:
+ doc: The document for which to extract keywords/keyphrases
+ keywords: the top n keywords for a document
+
+ Returns:
+ highlighted_text: The document with additional tags to highlight keywords
+ according to the rich package
+ """
+ tokens = re.sub(r' +', ' ', doc.replace("\n", " ")).split(" ")
+
+ highlighted_text = " ".join([f"[black on #FFFF00]{token}[/]"
+ if token.lower() in keywords
+ else f"{token}"
+ for token in tokens]).strip()
+ return highlighted_text
+
+
+def _highlight_n_gram(doc: str,
+ keywords: List[str]) -> str:
+ """ Highlight n-gram keywords in a document
+
+ Arguments:
+ doc: The document for which to extract keywords/keyphrases
+ keywords: the top n keywords for a document
+
+ Returns:
+ highlighted_text: The document with additional tags to highlight keywords
+ according to the rich package
+ """
+ max_len = max([len(token.split(" ")) for token in keywords])
+ tokens = re.sub(r' +', ' ', doc.replace("\n", " ")).strip().split(" ")
+ n_gram_tokens = [[" ".join(tokens[i: i + max_len][0: j + 1]) for j in range(max_len)] for i, _ in enumerate(tokens)]
+ highlighted_text = []
+ skip = False
+
+ for n_grams in n_gram_tokens:
+ candidate = False
+
+ if not skip:
+ for index, n_gram in enumerate(n_grams):
+
+ if n_gram.lower() in keywords:
+ candidate = f"[black on #FFFF00]{n_gram}[/]" + n_grams[-1].split(n_gram)[-1]
+ skip = index + 1
+
+ if not candidate:
+ candidate = n_grams[0]
+
+ highlighted_text.append(candidate)
+
+ else:
+ skip = skip - 1
+ highlighted_text = " ".join(highlighted_text)
+ return highlighted_text
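Internally this new module is driven by `extract_keywords(..., highlight=True)`, but as a minimal sketch of the function it exposes (the document and the keyword/score pairs below are made-up placeholders, not real model output):

```python
from keybert._highlight import highlight_document

doc = "Supervised learning infers a function from labeled training data consisting of training examples."
# (keyword, distance) pairs in the format returned by KeyBERT.extract_keywords; scores are illustrative
keywords = [("supervised learning", 0.60), ("training data", 0.55)]

# Prints the document to the console with the keyword spans highlighted via rich
highlight_document(doc, keywords)
```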
diff --git a/keybert/maxsum.py b/keybert/_maxsum.py
index 336b168..336b168 100644
--- a/keybert/maxsum.py
+++ b/keybert/_maxsum.py
diff --git a/keybert/mmr.py b/keybert/_mmr.py
index 88c804a..88c804a 100644
--- a/keybert/mmr.py
+++ b/keybert/_mmr.py
diff --git a/keybert/model.py b/keybert/_model.py
index 586d216..497adaf 100644
--- a/keybert/model.py
+++ b/keybert/_model.py
@@ -8,8 +8,9 @@ from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import CountVectorizer
# KeyBERT
-from keybert.mmr import mmr
-from keybert.maxsum import max_sum_similarity
+from keybert._mmr import mmr
+from keybert._maxsum import max_sum_similarity
+from keybert._highlight import highlight_document
from keybert.backend._utils import select_backend
@@ -30,7 +31,7 @@ class KeyBERT:
"""
def __init__(self,
- model='distilbert-base-nli-mean-tokens'):
+ model="paraphrase-MiniLM-L6-v2"):
""" KeyBERT initialization
Arguments:
@@ -58,8 +59,9 @@ class KeyBERT:
use_mmr: bool = False,
diversity: float = 0.5,
nr_candidates: int = 20,
- vectorizer: CountVectorizer = None) -> Union[List[Tuple[str, float]],
- List[List[Tuple[str, float]]]]:
+ vectorizer: CountVectorizer = None,
+ highlight: bool = False) -> Union[List[Tuple[str, float]],
+ List[List[Tuple[str, float]]]]:
""" Extract keywords/keyphrases
NOTE:
@@ -94,6 +96,9 @@ class KeyBERT:
nr_candidates: The number of candidates to consider if use_maxsum is
set to True
vectorizer: Pass in your own CountVectorizer from scikit-learn
+ highlight: Whether to print the document and highlight
+ its keywords/keyphrases. NOTE: This does not work if
+ multiple documents are passed.
Returns:
keywords: the top n keywords for a document with their respective distances
@@ -102,16 +107,21 @@ class KeyBERT:
"""
if isinstance(docs, str):
- return self._extract_keywords_single_doc(doc=docs,
- candidates=candidates,
- keyphrase_ngram_range=keyphrase_ngram_range,
- stop_words=stop_words,
- top_n=top_n,
- use_maxsum=use_maxsum,
- use_mmr=use_mmr,
- diversity=diversity,
- nr_candidates=nr_candidates,
- vectorizer=vectorizer)
+ keywords = self._extract_keywords_single_doc(doc=docs,
+ candidates=candidates,
+ keyphrase_ngram_range=keyphrase_ngram_range,
+ stop_words=stop_words,
+ top_n=top_n,
+ use_maxsum=use_maxsum,
+ use_mmr=use_mmr,
+ diversity=diversity,
+ nr_candidates=nr_candidates,
+ vectorizer=vectorizer)
+ if highlight:
+ highlight_document(docs, keywords)
+
+ return keywords
+
elif isinstance(docs, list):
warnings.warn("Although extracting keywords for multiple documents is faster "
"than iterating over single documents, it requires significantly more memory "
diff --git a/keybert/backend/_sentencetransformers.py b/keybert/backend/_sentencetransformers.py
index 6b998f7..60e4845 100644
--- a/keybert/backend/_sentencetransformers.py
+++ b/keybert/backend/_sentencetransformers.py
@@ -16,13 +16,13 @@ class SentenceTransformerBackend(BaseEmbedder):
sentence-transformers model:
```python
from keybert.backend import SentenceTransformerBackend
- sentence_model = SentenceTransformerBackend("distilbert-base-nli-stsb-mean-tokens")
+ sentence_model = SentenceTransformerBackend("paraphrase-MiniLM-L6-v2")
```
or you can instantiate a model yourself:
```python
from keybert.backend import SentenceTransformerBackend
from sentence_transformers import SentenceTransformer
- embedding_model = SentenceTransformer("distilbert-base-nli-stsb-mean-tokens")
+ embedding_model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
sentence_model = SentenceTransformerBackend(embedding_model)
```
"""
@@ -36,7 +36,7 @@ class SentenceTransformerBackend(BaseEmbedder):
else:
raise ValueError("Please select a correct SentenceTransformers model: \n"
"`from sentence_transformers import SentenceTransformer` \n"
- "`model = SentenceTransformer('distilbert-base-nli-stsb-mean-tokens')`")
+ "`model = SentenceTransformer('paraphrase-MiniLM-L6-v2')`")
def embed(self,
documents: List[str],
@@ -51,4 +51,4 @@ class SentenceTransformerBackend(BaseEmbedder):
that each have an embeddings size of `m`
"""
embeddings = self.embedding_model.encode(documents, show_progress_bar=verbose)
- return embeddings
\ No newline at end of file
+ return embeddings
diff --git a/keybert/backend/_utils.py b/keybert/backend/_utils.py
index 0e89bf1..4d13512 100644
--- a/keybert/backend/_utils.py
+++ b/keybert/backend/_utils.py
@@ -4,8 +4,9 @@ from ._sentencetransformers import SentenceTransformerBackend
def select_backend(embedding_model) -> BaseEmbedder:
""" Select an embedding model based on language or a specific sentence transformer models.
- When selecting a language, we choose distilbert-base-nli-stsb-mean-tokens for English and
- xlm-r-bert-base-nli-stsb-mean-tokens for all other languages as it support 100+ languages.
+ When selecting a language, we choose `paraphrase-MiniLM-L6-v2` for English and
+ `paraphrase-multilingual-MiniLM-L12-v2` for all other languages as it supports 100+ languages.
+
Returns:
model: Either a Sentence-Transformer or Flair model
"""
@@ -41,4 +42,4 @@ def select_backend(embedding_model) -> BaseEmbedder:
if isinstance(embedding_model, str):
return SentenceTransformerBackend(embedding_model)
- return SentenceTransformerBackend("xlm-r-bert-base-nli-stsb-mean-tokens")
\ No newline at end of file
+ return SentenceTransformerBackend("paraphrase-multilingual-MiniLM-L12-v2")
diff --git a/mkdocs.yml b/mkdocs.yml
index 2fdfcb5..ec906cd 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -14,6 +14,7 @@ nav:
- KeyBERT: api/keybert.md
- MMR: api/mmr.md
- MaxSum: api/maxsum.md
+ - FAQ: faq.md
- Changelog: changelog.md
plugins:
- mkdocstrings:
diff --git a/setup.py b/setup.py
index f2a6865..2527b71 100644
--- a/setup.py
+++ b/setup.py
@@ -9,6 +9,7 @@ base_packages = [
"sentence-transformers>=0.3.8",
"scikit-learn>=0.22.2",
"numpy>=1.18.5",
+ "rich>=10.4.0"
]
docs_packages = [
@@ -18,6 +19,7 @@ docs_packages = [
]
flair_packages = [
+ "transformers==3.5.1",
"torch>=1.4.0,<1.7.1",
"flair==0.7"
]
@@ -46,7 +48,7 @@ with open("README.md", "r", encoding='utf-8') as fh:
setup(
name="keybert",
packages=find_packages(exclude=["notebooks", "docs"]),
- version="0.3.0",
+ version="0.4.0",
author="Maarten Grootendorst",
author_email="maartengrootendorst@gmail.com",
description="KeyBERT performs keyword extraction with state-of-the-art transformer models.",
diff --git a/tests/conftest.py b/tests/conftest.py
deleted file mode 100644
index 50fc046..0000000
--- a/tests/conftest.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from keybert import KeyBERT
-import pytest
-
-
-@pytest.fixture(scope="module")
-def base_keybert():
- model = KeyBERT(model = 'distilbert-base-nli-mean-tokens')
- return model
diff --git a/tests/test_model.py b/tests/test_model.py
index 9362173..977f98e 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1,21 +1,24 @@
import pytest
from .utils import get_test_data
from sklearn.feature_extraction.text import CountVectorizer
+from keybert import KeyBERT
doc_one, doc_two = get_test_data()
+model = KeyBERT(model='paraphrase-MiniLM-L6-v2')
@pytest.mark.parametrize("keyphrase_length", [(1, i+1) for i in range(5)])
@pytest.mark.parametrize("vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")])
-def test_single_doc(keyphrase_length, vectorizer, base_keybert):
+def test_single_doc(keyphrase_length, vectorizer):
""" Test whether the keywords are correctly extracted """
top_n = 5
- keywords = base_keybert.extract_keywords(doc_one,
- keyphrase_ngram_range=keyphrase_length,
- min_df=1,
- top_n=top_n,
- vectorizer=vectorizer)
+ keywords = model.extract_keywords(doc_one,
+ keyphrase_ngram_range=keyphrase_length,
+ min_df=1,
+ top_n=top_n,
+ vectorizer=vectorizer)
+
assert isinstance(keywords, list)
assert isinstance(keywords[0], tuple)
assert isinstance(keywords[0][0], str)
@@ -29,16 +32,16 @@ def test_single_doc(keyphrase_length, vectorizer, base_keybert):
for i in range(4)
for truth in [True, False]])
@pytest.mark.parametrize("vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")])
-def test_extract_keywords_single_doc(keyphrase_length, mmr, maxsum, vectorizer, base_keybert):
+def test_extract_keywords_single_doc(keyphrase_length, mmr, maxsum, vectorizer):
""" Test extraction of protected single document method """
top_n = 5
- keywords = base_keybert._extract_keywords_single_doc(doc_one,
- top_n=top_n,
- keyphrase_ngram_range=keyphrase_length,
- use_mmr=mmr,
- use_maxsum=maxsum,
- diversity=0.5,
- vectorizer=vectorizer)
+ keywords = model._extract_keywords_single_doc(doc_one,
+ top_n=top_n,
+ keyphrase_ngram_range=keyphrase_length,
+ use_mmr=mmr,
+ use_maxsum=maxsum,
+ diversity=0.5,
+ vectorizer=vectorizer)
assert isinstance(keywords, list)
assert isinstance(keywords[0][0], str)
assert isinstance(keywords[0][1], float)
@@ -48,12 +51,12 @@ def test_extract_keywords_single_doc(keyphrase_length, mmr, maxsum, vectorizer,
@pytest.mark.parametrize("keyphrase_length", [(1, i+1) for i in range(5)])
-def test_extract_keywords_multiple_docs(keyphrase_length, base_keybert):
+def test_extract_keywords_multiple_docs(keyphrase_length):
""" Test extractino of protected multiple document method"""
top_n = 5
- keywords_list = base_keybert._extract_keywords_multiple_docs([doc_one, doc_two],
- top_n=top_n,
- keyphrase_ngram_range=keyphrase_length)
+ keywords_list = model._extract_keywords_multiple_docs([doc_one, doc_two],
+ top_n=top_n,
+ keyphrase_ngram_range=keyphrase_length)
assert isinstance(keywords_list, list)
assert isinstance(keywords_list[0], list)
assert len(keywords_list) == 2
@@ -65,16 +68,16 @@ def test_extract_keywords_multiple_docs(keyphrase_length, base_keybert):
assert len(keyword[0].split(" ")) <= keyphrase_length[1]
-def test_error(base_keybert):
+def test_error():
""" Empty doc should raise a ValueError """
with pytest.raises(AttributeError):
doc = []
- base_keybert._extract_keywords_single_doc(doc)
+ model._extract_keywords_single_doc(doc)
-def test_empty_doc(base_keybert):
+def test_empty_doc():
""" Test empty document """
doc = ""
- result = base_keybert._extract_keywords_single_doc(doc)
+ result = model._extract_keywords_single_doc(doc)
assert result == []