Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author John Bauer <horatio@gmail.com> 2022-07-14 01:40:39 +0300
committer John Bauer <horatio@gmail.com> 2022-07-14 01:43:35 +0300
commit 8e0f2fe37838b198c6d94b03a7c1b8a33e0bfaf4 (patch)
tree 1fa5bfd6a9c51d7de591654bd720816d5ae527a0
parent 02bfa5d0804741a8382c6fb42cb124e52b51be8a (diff)
Mask illegal languages by setting them to -inf. A mask of 0 means that illegal languages can be chosen if all legal languages were negative (ref: ninf_langid)
Addresses #1076
-rw-r--r-- stanza/models/langid/model.py 12
-rw-r--r-- stanza/tests/langid/test_langid.py 17
2 files changed, 25 insertions, 4 deletions
diff --git a/stanza/models/langid/model.py b/stanza/models/langid/model.py
index 082738ea..f75426cd 100644
--- a/stanza/models/langid/model.py
+++ b/stanza/models/langid/model.py
@@ -58,11 +58,15 @@ class LangIDBiLSTM(nn.Module):
def build_lang_mask(self, use_gpu=None):
"""
Build language mask if a lang subset is specified (e.g. ["en", "fr"])
+
+ The mask will be added to the results to set the prediction scores of illegal languages to -inf
"""
device = torch.device("cuda") if use_gpu else None
- lang_mask_list = [int(lang in self.lang_subset) for lang in self.idx_to_tag] if self.lang_subset else \
- [1 for lang in self.idx_to_tag]
- self.lang_mask = torch.tensor(lang_mask_list, device=device, dtype=torch.float)
+ if self.lang_subset:
+ lang_mask_list = [0.0 if lang in self.lang_subset else -float('inf') for lang in self.idx_to_tag]
+ self.lang_mask = torch.tensor(lang_mask_list, device=device, dtype=torch.float)
+ else:
+ self.lang_mask = torch.zeros(len(self.idx_to_tag), device=device, dtype=torch.float)
def loss(self, Y_hat, Y):
return self.loss_train(Y_hat, Y)
@@ -87,7 +91,7 @@ class LangIDBiLSTM(nn.Module):
if self.lang_subset:
prediction_batch_size = prediction_probs.size()[0]
batch_mask = torch.stack([self.lang_mask for _ in range(prediction_batch_size)])
- prediction_probs = prediction_probs * batch_mask
+ prediction_probs = prediction_probs + batch_mask
return torch.argmax(prediction_probs, dim=1)
def save(self, path):
diff --git a/stanza/tests/langid/test_langid.py b/stanza/tests/langid/test_langid.py
index 5dd36125..7dba9e40 100644
--- a/stanza/tests/langid/test_langid.py
+++ b/stanza/tests/langid/test_langid.py
@@ -578,6 +578,23 @@ def test_lang_subset():
nlp(docs)
assert [doc.lang for doc in docs] == ["en", "en"]
+def test_lang_subset_unlikely_language():
+ """
+ Test that the language subset masking chooses a legal language, even if all legal languages are supa unlikely
+ """
+ sentences = ["你好" * 200]
+ docs = [Document([], text=text) for text in sentences]
+ nlp = Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_lang_subset=["en"])
+ nlp(docs)
+ assert [doc.lang for doc in docs] == ["en"]
+
+ processor = nlp.processors['langid']
+ model = processor._model
+ text_tensor = processor._text_to_tensor(sentences)
+ en_idx = model.tag_to_idx['en']
+ predictions = model(text_tensor)
+ assert predictions[0, en_idx] < 0, "If this test fails, then regardless of how unlikely it was, the model is predicting the input string is possibly English. Update the test by picking a different combination of languages & input"
+
def test_multilingual_pipeline():
"""
Basic test of multilingual pipeline