diff options
author | Maarten Grootendorst <maarten_grootendorst@hotmail.com> | 2020-10-27 11:02:11 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-10-27 11:02:11 +0300 |
commit | 222cc5b96efacee01def4dc0b4c366be314a8ec1 (patch) | |
tree | 3cf82bd06feb47d78c3646c2e69c835805ddb4f5 | |
parent | 5d408d7db170f5af803ef13dd57fd9e18236697e (diff) |
Unit tests (#1, #3)
* Added unit tests
* Added documentation
-rw-r--r-- | .github/workflows/testing.yml | 31 | ||||
-rw-r--r-- | README.md | 136 | ||||
-rw-r--r-- | docs/algorithm.md | 1 | ||||
-rw-r--r-- | docs/img/icon.png | bin | 13661 -> 0 bytes | |||
-rw-r--r-- | images/icon.png | bin | 15208 -> 21723 bytes | |||
-rw-r--r-- | images/logo.png | bin | 17931 -> 36703 bytes | |||
-rw-r--r-- | keybert/__init__.py | 1 | ||||
-rw-r--r-- | keybert/mmr.py | 92 | ||||
-rw-r--r-- | keybert/model.py | 29 | ||||
-rw-r--r-- | mkdocs.yml | 2 | ||||
-rw-r--r-- | tests/conftest.py | 23 | ||||
-rw-r--r-- | tests/test_model.py | 81 | ||||
-rw-r--r-- | tests/utils.py | 24 |
13 files changed, 377 insertions, 43 deletions
diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml new file mode 100644 index 0000000..84ede17 --- /dev/null +++ b/.github/workflows/testing.yml @@ -0,0 +1,31 @@ +name: Code Checks + +on: + push: + branches: + - master + - dev + pull_request: + branches: + - master + - dev + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.6, 3.7, 3.8] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[test]" + - name: Run Checking Mechanisms + run: make check @@ -3,4 +3,138 @@ [![PyPI - PyPi](https://img.shields.io/pypi/v/keyBERT)](https://pypi.org/project/keybert/) [![Build](https://img.shields.io/github/workflow/status/MaartenGr/keyBERT/Code%20Checks/master)](https://pypi.org/project/keybert/) -# KeyBERT
\ No newline at end of file +<img src="images/logo.png" width="35%" height="35%" align="right" /> + +# KeyBERT + +KeyBERT is a minimal and easy-to-use keyword extraction technique that leverages BERT embeddings to +create keywords and keyphrases that are most similar to a document. + +Corresponding medium post can be found [here](). + +<a name="toc"/></a> +## Table of Contents +<!--ts--> + 1. [About the Project](#about) + 2. [Getting Started](#gettingstarted) + 2.1. [Installation](#installation) + 2.2. [Basic Usage](#usage) +<!--te--> + + +<a name="about"/></a> +## 1. About the Project +[Back to ToC](#toc) + +Although that are already many methods available for keyword generation +(e.g., +[Rake](https://github.com/aneesha/RAKE), +[YAKE!](https://github.com/LIAAD/yake), TF-IDF, etc.) +I wanted to create a very basic, but powerful method for extracting keywords and keyphrases. +This is where **KeyBERT** comes in! Which uses BERT-embeddings and simple cosine similarity +to find the sub-phrases in a document that are the most similar to the document itself. + +First, document embeddings are extracted with BERT to get a document-level representation. +Then, word embeddings are extracted for N-gram words/phrases. Finally, we use cosine similarity +to find the words/phrases that are the most similar to the document. The most similar words could +then be identified as the words that best describe the entire document. + +KeyBERT is by no means unique and is created as a quick and easy method +for creating keywords and keyphrases. Although there are many great +papers and solutions out there that use BERT-embeddings +(e.g., +[1](https://github.com/pranav-ust/BERT-keyphrase-extraction), +[2](https://github.com/ibatra/BERT-Keyword-Extractor), +[3](https://www.preprints.org/manuscript/201908.0073/download/final_file), +), I could not find a BERT-based solution that did not have to be trained from scratch and +could be used for beginners (**correct me if I'm wrong!**). +Thus, the goal was a `pip install keybert` and at most 3 lines of code in usage. + +**NOTE**: If you use MMR to select the candidates instead of simple cosine similarity, +this repo is essentially a simplified implementation of +[EmbedRank](https://github.com/swisscom/ai-research-keyphrase-extraction) +with BERT-embeddings. + + +<a name="gettingstarted"/></a> +## 2. Getting Started +[Back to ToC](#toc) + +<a name="installation"/></a> +### 2.1. Installation +**[PyTorch 1.2.0](https://pytorch.org/get-started/locally/)** or higher is recommended. If the install below gives an +error, please install pytorch first [here](https://pytorch.org/get-started/locally/). + +Installation can be done using [pypi](https://pypi.org/project/bertopic/): + +``` +pip install keybert +``` + +<a name="usage"/></a> +### 2.2. Usage + +The most minimal example can be seen below for the extraction of keywords: +```python +from keybert import KeyBERT + +doc = """ + Supervised learning is the machine learning task of learning a function that + maps an input to an output based on example input-output pairs.[1] It infers a + function from labeled training data consisting of a set of training examples.[2] + In supervised learning, each example is a pair consisting of an input object + (typically a vector) and a desired output value (also called the supervisory signal). + A supervised learning algorithm analyzes the training data and produces an inferred function, + which can be used for mapping new examples. An optimal scenario will allow for the + algorithm to correctly determine the class labels for unseen instances. This requires + the learning algorithm to generalize from the training data to unseen situations in a + 'reasonable' way (see inductive bias). + """ +model = KeyBERT('distilbert-base-nli-mean-tokens') +keywords = model.extract_keywords(doc) +``` + +You can set `keyphrase_length` to set the length of the resulting keyphras: + +```python +>>> model.extract_keywords(doc, keyphrase_length=1, stop_words=None) +['learning', + 'training', + 'algorithm', + 'class', + 'mapping'] +``` + +To extract keyphrases, simply set `keyphrase_length` to 2 or higher depending on the number +of words you would like in the resulting keyphrases: + +```python +>>> model.extract_keywords(doc, keyphrase_length=3, stop_words=None) +['learning algorithm', + 'learning machine', + 'machine learning', + 'supervised learning', + 'learning function'] +``` + +## References +Below, you can find several resources that were used for the creation of KeyBERT +but most importantly, are amazing resources for creating impressive keyword extraction models: + +**Papers**: +* Sharma, P., & Li, Y. (2019). [Self-Supervised Contextual Keyword and Keyphrase Retrieval with Self-Labelling.](https://www.preprints.org/manuscript/201908.0073/download/final_file) + +**Github Repos**: +* https://github.com/thunlp/BERT-KPE +* https://github.com/ibatra/BERT-Keyword-Extractor +* https://github.com/pranav-ust/BERT-keyphrase-extraction +* https://github.com/swisscom/ai-research-keyphrase-extraction + +**MMR**: +The selection of keywords/keyphrases was modelled after: +* https://github.com/swisscom/ai-research-keyphrase-extraction + +**NOTE**: If you find a paper or github repo that has an easy-to-use implementation +of BERT-embeddings for keyword/keyphrase extraction, let me know! I'll make sure to +add it a reference to this repo. + diff --git a/docs/algorithm.md b/docs/algorithm.md deleted file mode 100644 index 47776a0..0000000 --- a/docs/algorithm.md +++ /dev/null @@ -1 +0,0 @@ -# The Algorithm
\ No newline at end of file diff --git a/docs/img/icon.png b/docs/img/icon.png Binary files differdeleted file mode 100644 index 50002cd..0000000 --- a/docs/img/icon.png +++ /dev/null diff --git a/images/icon.png b/images/icon.png Binary files differindex 93765ed..da6c8be 100644 --- a/images/icon.png +++ b/images/icon.png diff --git a/images/logo.png b/images/logo.png Binary files differindex 6b81303..6743d32 100644 --- a/images/logo.png +++ b/images/logo.png diff --git a/keybert/__init__.py b/keybert/__init__.py index e69de29..1d41a2e 100644 --- a/keybert/__init__.py +++ b/keybert/__init__.py @@ -0,0 +1 @@ +from keybert.model import KeyBERT diff --git a/keybert/mmr.py b/keybert/mmr.py new file mode 100644 index 0000000..2f72cff --- /dev/null +++ b/keybert/mmr.py @@ -0,0 +1,92 @@ +# Copyright (c) 2017-present, Swisscom (Schweiz) AG. +# All rights reserved. +# +#Authors: Kamil Bennani-Smires, Yann Savary + + +import numpy as np +from sklearn.metrics.pairwise import cosine_similarity + + +def MMR(doc_embedd, candidates, X, beta, N): + """ + Core method using Maximal Marginal Relevance in charge to return the top-N candidates + :param candidates: list of candidates (string) + :param X: numpy array with the embedding of each candidate in each row + :param beta: hyperparameter beta for MMR (control tradeoff between informativeness and diversity) + :param N: number of candidates to extract + :return: A tuple with 3 elements : + 1)list of the top-N candidates (or less if there are not enough candidates) (list of string) + 2)list of associated relevance scores (list of float) + 3)list containing for each keyphrase a list of alias (list of list of string) + """ + + N = min(N, len(candidates)) + doc_sim = cosine_similarity(X, doc_embedd.reshape(1, -1)) + + doc_sim_norm = doc_sim/np.max(doc_sim) + doc_sim_norm = 0.5 + (doc_sim_norm - np.average(doc_sim_norm)) / np.std(doc_sim_norm) + + sim_between = cosine_similarity(X) + np.fill_diagonal(sim_between, np.NaN) + + sim_between_norm = sim_between/np.nanmax(sim_between, axis=0) + sim_between_norm = \ + 0.5 + (sim_between_norm - np.nanmean(sim_between_norm, axis=0)) / np.nanstd(sim_between_norm, axis=0) + + selected_candidates = [] + unselected_candidates = [c for c in range(len(candidates))] + + j = int(np.argmax(doc_sim)) + selected_candidates.append(j) + unselected_candidates.remove(j) + + for _ in range(N - 1): + selec_array = np.array(selected_candidates) + unselec_array = np.array(unselected_candidates) + + distance_to_doc = doc_sim_norm[unselec_array, :] + dist_between = sim_between_norm[unselec_array][:, selec_array] + if dist_between.ndim == 1: + dist_between = dist_between[:, np.newaxis] + j = np.argmax(beta * distance_to_doc - (1 - beta) * np.max(dist_between, axis=1).reshape(-1, 1)) + item_idx = unselected_candidates[j] + selected_candidates.append(item_idx) + unselected_candidates.remove(item_idx) + + return candidates, selected_candidates + + +def max_normalization(array): + """ + Compute maximum normalization (max is set to 1) of the array + :param array: 1-d array + :return: 1-d array max- normalized : each value is multiplied by 1/max value + """ + return 1/np.max(array) * array.squeeze(axis=1) + + +def get_aliases(kp_sim_between, candidates, threshold): + """ + Find candidates which are very similar to the keyphrases (aliases) + :param kp_sim_between: ndarray of shape (nb_kp , nb candidates) containing the similarity + of each kp with all the candidates. Note that the similarity between the keyphrase and itself should be set to + NaN or 0 + :param candidates: array of candidates (array of string) + :return: list containing for each keyphrase a list that contain candidates which are aliases + (very similar) (list of list of string) + """ + + kp_sim_between = np.nan_to_num(kp_sim_between, 0) + idx_sorted = np.flip(np.argsort(kp_sim_between), 1) + aliases = [] + for kp_idx, item in enumerate(idx_sorted): + alias_for_item = [] + for i in item: + if kp_sim_between[kp_idx, i] >= threshold: + alias_for_item.append(candidates[i]) + else: + break + aliases.append(alias_for_item) + + return aliases diff --git a/keybert/model.py b/keybert/model.py index 1b8f7ac..0910c2b 100644 --- a/keybert/model.py +++ b/keybert/model.py @@ -8,6 +8,25 @@ import warnings class KeyBERT: + """ + A minimal method for keyword extraction with BERT + + The keyword extraction is done by finding the sub-phrases in + a document that are the most similar to the document itself. + + First, document embeddings are extracted with BERT to get a + document-level representation. Then, word embeddings are extracted + for N-gram words/phrases. Finally, we use cosine similarity to find the + words/phrases that are the most similar to the document. + + The most similar words could then be identified as the words that + best describe the entire document. + + Arguments: + model: the name of the model used by sentence-transformer + for a full overview see https://www.sbert.net/docs/pretrained_models.html + + """ def __init__(self, model: str = 'distilbert-base-nli-mean-tokens'): self.model = SentenceTransformer(model) self.doc_embeddings = None @@ -20,10 +39,10 @@ class KeyBERT: min_df: int = 1) -> Union[List[str], List[List[str]]]: """ Extract keywords/keyphrases - NOTE: I would advise you to use - - Single Document: - + NOTE: + I would advise you to iterate over single documents as they + will need the least amount of memory. Even though this is slower, + you are not likely to run into memory errors. Multiple Documents: There is an option to extract keywords for multiple documents @@ -44,7 +63,7 @@ class KeyBERT: if keywords for multiple documents need to be extracted Returns: - keywords: The top n keywords for a document + keywords: the top n keywords for a document """ @@ -29,7 +29,7 @@ theme: feature: tabs: true palette: - primary: indigo + primary: black accent: blue markdown_extensions: - codehilite diff --git a/tests/conftest.py b/tests/conftest.py index b3bbd19..50fc046 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,8 @@ -# from bertopic import BERTopic -# import pytest -# -# -# @pytest.fixture(scope="module") -# def base_bertopic(): -# model = BERTopic(bert_model='distilbert-base-nli-mean-tokens', -# top_n_words=20, -# nr_topics=None, -# n_gram_range=(1, 1), -# min_topic_size=30, -# n_neighbors=15, -# n_components=5, -# verbose=False) -# return model
\ No newline at end of file +from keybert import KeyBERT +import pytest + + +@pytest.fixture(scope="module") +def base_keybert(): + model = KeyBERT(model = 'distilbert-base-nli-mean-tokens') + return model diff --git a/tests/test_model.py b/tests/test_model.py index 7f5c52a..2535e4c 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,20 +1,61 @@ -# import pytest -# import numpy as np -# import pandas as pd -# from unittest import mock -# -# from sklearn.datasets import fetch_20newsgroups, make_blobs -# from keybert import KeyBERT -# -# newsgroup_docs = fetch_20newsgroups(subset='all')['data'][:1000] -# -# @mock.patch("bertopic.model.BERTopic._extract_embeddings") -# def test_fit_transform(embeddings, base_bertopic): -# """ Test whether predictions are correctly made """ -# blobs, _ = make_blobs(n_samples=len(newsgroup_docs), centers=5, n_features=768, random_state=42) -# embeddings.return_value = blobs -# predictions = base_bertopic.fit_transform(newsgroup_docs) -# -# assert isinstance(predictions, list) -# assert len(predictions) == len(newsgroup_docs) -# assert not set(predictions).difference(set(base_bertopic.get_topics().keys())) +import pytest +from .utils import get_test_data + +doc_one, doc_two = get_test_data() + + +@pytest.mark.parametrize("keyphrase_length", [i+1 for i in range(5)]) +def test_single_doc(keyphrase_length, base_keybert): + """ Test whether the keywords are correctly extracted """ + top_n = 5 + keywords = base_keybert.extract_keywords(doc_one, keyphrase_length=keyphrase_length, min_df=1, top_n=top_n) + assert isinstance(keywords, list) + assert isinstance(keywords[0], str) + assert len(keywords) == top_n + for keyword in keywords: + assert len(keyword.split(" ")) == keyphrase_length + + +@pytest.mark.parametrize("keyphrase_length", [i+1 for i in range(5)]) +def test_extract_keywords_single_doc(keyphrase_length, base_keybert): + """ Test extraction of protected single document method """ + top_n = 5 + keywords = base_keybert._extract_keywords_single_doc(doc_one, top_n=top_n, keyphrase_length=keyphrase_length) + assert isinstance(keywords, list) + assert isinstance(keywords[0], str) + assert len(keywords) == top_n + for keyword in keywords: + assert len(keyword.split(" ")) == keyphrase_length + + +@pytest.mark.parametrize("keyphrase_length", [i+1 for i in range(5)]) +def test_extract_keywords_multiple_docs(keyphrase_length, base_keybert): + """ Test extractino of protected multiple document method""" + top_n = 5 + keywords_list = base_keybert._extract_keywords_multiple_docs([doc_one, doc_two], + top_n=top_n, + keyphrase_length=keyphrase_length) + assert isinstance(keywords_list, list) + assert isinstance(keywords_list[0], list) + assert len(keywords_list) == 2 + + for keywords in keywords_list: + assert len(keywords) == top_n + + for keyword in keywords: + assert len(keyword.split(" ")) == keyphrase_length + + +def test_error(base_keybert): + """ Empty doc should raise a ValueError """ + with pytest.raises(AttributeError): + doc = [] + base_keybert._extract_keywords_single_doc(doc) + + +def test_empty_doc(base_keybert): + """ Test empty document """ + doc = "" + result = base_keybert._extract_keywords_single_doc(doc) + + assert result == [] diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..43bc904 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,24 @@ +def get_test_data(): + doc_one = "\n\nI am sure some bashers of Pens fans are pretty confused about the lack\nof " \ + "any kind of posts about the recent Pens massacre of the Devils. Actually,\nI am " \ + "bit puzzled too and a bit relieved. However, I am going to put an end\nto non-PIttsburghers' " \ + "relief with a bit of praise for the Pens. Man, they\nare killing those Devils worse than I thought. " \ + "Jagr just showed you why\nhe is much better than his regular season stats. " \ + "He is also a lot\nfo fun to watch in the playoffs. Bowman should let JAgr have " \ + "a lot of\nfun in the next couple of games since the Pens are going to beat the " \ + "pulp out of Jersey anyway. I was very disappointed not to see the Islanders lose " \ + "the final\nregular season game. PENS RULE!!!\n\n" + + doc_two = "\n[stuff deleted]\n\nOk, here's the solution to your problem. " \ + "Move to Canada. Yesterday I was able\nto watch FOUR games...the NJ-PITT " \ + "at 1:00 on ABC, LA-CAL at 3:00 (CBC), \nBUFF-BOS at 7:00 (TSN and FOX), " \ + "and MON-QUE at 7:30 (CBC). I think that if\neach series goes its max I " \ + "could be watching hockey playoffs for 40-some odd\nconsecutive nights " \ + "(I haven't counted so that's a pure guess).\n\nI have two tv's in my house, " \ + "and I set them up side-by-side to watch MON-QUE\nand keep an eye on " \ + "BOS-BUFF at the same time. I did the same for the two\nafternoon games." \ + "\n\nBtw, those ABC commentaters were great! I was quite impressed; they " \ + "seemed\nto know that their audience wasn't likely to be well-schooled in " \ + "hockey lore\nand they did an excellent job. They were quite impartial also, IMO.\n\n" + + return doc_one, doc_two |