diff options
author | John Bauer <horatio@gmail.com> | 2022-07-14 01:40:39 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-07-14 01:43:35 +0300 |
commit | 8e0f2fe37838b198c6d94b03a7c1b8a33e0bfaf4 (patch) | |
tree | 1fa5bfd6a9c51d7de591654bd720816d5ae527a0 | |
parent | 02bfa5d0804741a8382c6fb42cb124e52b51be8a (diff) |
Mask illegal langauges by setting them to -ninf. 0 means that illegal languages can be chosen if all legal languages were negativeninf_langid
Addresses #1076
-rw-r--r-- | stanza/models/langid/model.py | 12 | ||||
-rw-r--r-- | stanza/tests/langid/test_langid.py | 17 |
2 files changed, 25 insertions, 4 deletions
diff --git a/stanza/models/langid/model.py b/stanza/models/langid/model.py index 082738ea..f75426cd 100644 --- a/stanza/models/langid/model.py +++ b/stanza/models/langid/model.py @@ -58,11 +58,15 @@ class LangIDBiLSTM(nn.Module): def build_lang_mask(self, use_gpu=None): """ Build language mask if a lang subset is specified (e.g. ["en", "fr"]) + + The mask will be added to the results to set the prediction scores of illegal languages to -inf """ device = torch.device("cuda") if use_gpu else None - lang_mask_list = [int(lang in self.lang_subset) for lang in self.idx_to_tag] if self.lang_subset else \ - [1 for lang in self.idx_to_tag] - self.lang_mask = torch.tensor(lang_mask_list, device=device, dtype=torch.float) + if self.lang_subset: + lang_mask_list = [0.0 if lang in self.lang_subset else -float('inf') for lang in self.idx_to_tag] + self.lang_mask = torch.tensor(lang_mask_list, device=device, dtype=torch.float) + else: + self.lang_mask = torch.zeros(len(self.idx_to_tag), device=device, dtype=torch.float) def loss(self, Y_hat, Y): return self.loss_train(Y_hat, Y) @@ -87,7 +91,7 @@ class LangIDBiLSTM(nn.Module): if self.lang_subset: prediction_batch_size = prediction_probs.size()[0] batch_mask = torch.stack([self.lang_mask for _ in range(prediction_batch_size)]) - prediction_probs = prediction_probs * batch_mask + prediction_probs = prediction_probs + batch_mask return torch.argmax(prediction_probs, dim=1) def save(self, path): diff --git a/stanza/tests/langid/test_langid.py b/stanza/tests/langid/test_langid.py index 5dd36125..7dba9e40 100644 --- a/stanza/tests/langid/test_langid.py +++ b/stanza/tests/langid/test_langid.py @@ -578,6 +578,23 @@ def test_lang_subset(): nlp(docs) assert [doc.lang for doc in docs] == ["en", "en"] +def test_lang_subset_unlikely_language(): + """ + Test that the language subset masking chooses a legal language, even if all legal languages are supa unlikely + """ + sentences = ["你好" * 200] + docs = [Document([], text=text) for text in sentences] + nlp = Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_lang_subset=["en"]) + nlp(docs) + assert [doc.lang for doc in docs] == ["en"] + + processor = nlp.processors['langid'] + model = processor._model + text_tensor = processor._text_to_tensor(sentences) + en_idx = model.tag_to_idx['en'] + predictions = model(text_tensor) + assert predictions[0, en_idx] < 0, "If this test fails, then regardless of how unlikely it was, the model is predicting the input string is possibly English. Update the test by picking a different combination of languages & input" + def test_multilingual_pipeline(): """ Basic test of multilingual pipeline |