Welcome to mirror list, hosted at ThFree Co, Russian Federation.

test_model.py « tests - github.com/MaartenGr/KeyBERT.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 7f5c52af8704fe2ae619ef67bfba9719c8b25e4e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# import pytest
# import numpy as np
# import pandas as pd
# from unittest import mock
#
# from sklearn.datasets import fetch_20newsgroups, make_blobs
# from keybert import KeyBERT
#
# newsgroup_docs = fetch_20newsgroups(subset='all')['data'][:1000]
#
# @mock.patch("bertopic.model.BERTopic._extract_embeddings")
# def test_fit_transform(embeddings, base_bertopic):
#     """ Test whether predictions are correctly made """
#     blobs, _ = make_blobs(n_samples=len(newsgroup_docs), centers=5, n_features=768, random_state=42)
#     embeddings.return_value = blobs
#     predictions = base_bertopic.fit_transform(newsgroup_docs)
#
#     assert isinstance(predictions, list)
#     assert len(predictions) == len(newsgroup_docs)
#     assert not set(predictions).difference(set(base_bertopic.get_topics().keys()))