author    | Linxiao ZENG <linxiao.zeng@gmail.com> | 2021-04-26 11:36:59 +0300
committer | GitHub <noreply@github.com>           | 2021-04-26 11:36:59 +0300
commit    | 1e8173cbfaba6f1248b6df53f5fd00892ce2e69c
tree      | d6debcf84138e126d22a7cdebd5d8d1ad810522c
parent    | 5c2eadada650b15f1bf46ec86f6e7008962efedd
Add some transforms unittest (#2049)
-rw-r--r-- | data/sample.bpe              | 1001
-rw-r--r-- | data/sample.sp.model         | bin 0 -> 252917 bytes
-rw-r--r-- | onmt/tests/test_transform.py | 465
-rw-r--r-- | onmt/transforms/bart.py      | 38
-rw-r--r-- | onmt/transforms/sampling.py  | 39
-rw-r--r-- | onmt/transforms/tokenize.py  | 27
6 files changed, 1517 insertions(+), 53 deletions(-)
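
The new tests live in onmt/tests/test_transform.py and exercise the sample subword models added under data/. As a minimal sketch (not part of the commit), assuming an importable OpenNMT-py checkout with the repository root as working directory (so data/sample.bpe and data/sample.sp.model resolve) and the optional dependencies these tests touch (subword-nmt, sentencepiece, pyonmttok) installed, they could be run with the standard-library unittest runner:

    # Illustrative only: load and run the test module added by this commit.
    import unittest

    suite = unittest.defaultTestLoader.loadTestsFromName("onmt.tests.test_transform")
    unittest.TextTestRunner(verbosity=2).run(suite)
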
diff --git a/data/sample.bpe b/data/sample.bpe new file mode 100644 index 00000000..59e66f35 --- /dev/null +++ b/data/sample.bpe @@ -0,0 +1,1001 @@ +#version: 0.2 +t h +i n +th e</w> +a n +r e +t i +e r +e n +o n +a r +o u +o f</w> +o n</w> +o r +an d</w> +t o</w> +t e +r o +in g</w> +i s</w> +i n</w> +a l +i t +e s</w> +i s +e d</w> +e r</w> +a t</w> +o r</w> +a l</w> +o m +i c +s t +a ti +s i +a c +a n</w> +l y</w> +e c +a s</w> +i l +p o +a t +e s +T h +e l +v e</w> +m en +u r +p ro +a s +th at</w> +f or</w> +l i +t s</w> +l e</w> +w h +l e +t r +u s +n o +e n</w> +c on +d e +ati on</w> +s u +h a +ar e</w> +l o +a b +a g +a d +te d</w> +it h</w> +q u +u n +Th e</w> +o p +c om +w ith</w> +b e</w> +d i +po s +b e +a m +ti on</w> +r i +c h +men t</w> +ou r</w> +i t</w> +re s +c e</w> +l d</w> +ou n +p or +e x +p e +on s</w> +th is</w> +u l +f or +v er +n e +o l +th e +t a +w e</w> +c i +s h +o u</w> +o t +a y</w> +ha ve</w> +i m +er s</w> +it y</w> +er e</w> +no t</w> +f f +v i +g h +a p +b y</w> +en t</w> +ou ld</w> +y ou</w> +d u +e m +al l</w> +il l</w> +b u +p l +u ro +at e</w> +t er +m o +E uro +w ill</w> +& a +&a pos +s e +ro m</w> +i r +ic h</w> +a r</w> +f rom</w> +w or +k e</w> +te r</w> +s t</w> +wh ich</w> +p ar +m a +s o</w> +r a +s e</w> +p er +h as</w> +m e +th er</w> +ou t</w> +o w +d s</w> +c o +re n +ab le</w> +on e</w> +c h</w> +ur e</w> +c an</w> +si on</w> +m an +es s</w> +s p +e t</w> +i es</w> +e v +ou r +o o +d is +on al</w> +C om +an t</w> +si d +te s</w> +al ly</w> +re c +&apos ; +m is +w as</w> +ti c +Euro pe +g r +Europe an</w> +i r</w> +por t</w> +h o +m e</w> +us t</w> +d ing</w> +du c +or e</w> +a ted</w> +c l +W e</w> +ag e</w> +v e +y our</w> +n i +in g +al so</w> +ti ons</w> +a y +t o +u s</w> +gh t</w> +c oun +g e</w> +it s</w> +t ing</w> +b er</w> +t u +f in +qu ot +f e +t w +i c</w> +r u +s o +' s</w> +com p +u m +h i +& quot +re d</w> +Th is</w> +k ing</w> +c re +men ts</w> +d e</w> +" ;</w> +m ore</w> +ati ons</w> +s er +be en</w> +p res +es s +s te +I n</w> +t h</w> +an ce</w> +w ould</w> +ou s</w> +re g +Com mis +for m +ver y</w> +e ar +M r</w> +i ti +d ed</w> +sh ould</w> +the ir</w> +I t</w> +us e</w> +ar y</w> +bu t</w> +d er</w> +v er</w> +d o +f i +por t +m in +0 0 +c e +g u +g h</w> +en t +f ac +a v +es e</w> +re sid +ti ve</w> +a u +p r +w i +ec t</w> +f u +t y</w> +si b +g o +es t</w> +om e</w> +ic e</w> +o ther</w> +o ff +k s</w> +t en +no w</w> +coun tr +d er +as e</w> +l a +P resid +in e</w> +p re +s ti +re s</w> +v el +con t +p u +P ar +ab out</w> +as t</w> +w el +d o</w> +m ar +d ay</w> +Commis sion</w> +h e +a d</w> +su p +th ere</w> +cl u +the y</w> +ne w</w> +g e +li ke</w> +al l +ec on +en ce</w> +pro duc +s el +com m +k e +ati onal</w> +a in +ar d</w> +a m</w> +in ter +p e</w> +m ust</w> +an y</w> +ec i +ti me</w> +Presid ent</w> +c er +r ou +wor k</w> +f ir +a i +m s</w> +an s +on ly</w> +I n +U ni +de vel +v is +l ic +o b +te l</w> +a re +it e</w> +v es</w> +em ber</w> +g i +w e +re e</w> +2 00 +s y +th ese</w> +m on +ad e</w> +ci al</w> +t or +ac e</w> +c u +oun d</w> +&apos ;</w> +al s</w> +si g +ar i +res p +t e</w> +im port +for e</w> +lo c +devel op +am e</w> +u p</w> +li a +man y</w> +wh o</w> +oo d</w> +l l</w> +r on +w ay</w> +Euro pe</w> +an d +b ec +m y</w> +w ere</w> +a f +ren t</w> +lo w +wh at</w> +e st +ser v +di ff +ow n</w> +p ri +qu i +su b +ic al</w> +ing s</w> +ac c +an s</w> +at ing</w> +c r +pe op +ac h</w> +n o</w> +su ch</w> +t ur +wi th +il ity</w> +n ational</w> +t ure</w> +S ta +re e +a te +l is +ne 
ed</w> +ay s</w> +mo st</w> +en ti +as s +b i +lia ment</w> +sib le</w> +c es</w> +i on</w> +peop le</w> +ch an +op er +t ly</w> +el i +g y</w> +iti es</w> +s ec +tic al</w> +en er +ol u +vi e +Par liament</w> +am p +i f</w> +re port</w> +b o +ev er</w> +in to</w> +st ru +. . +p le +sy ste +th an</w> +u e</w> +p h +r an +fir st</w> +in form +Sta tes</w> +countr ies</w> +en tr +f ul +v ed</w> +v en +Uni on</w> +il l +E U</w> +econ om +p en +wh en</w> +c le +fu l</w> +s ome</w> +tr ans +men t +tr i +g re +import ant</w> +in te +wel l</w> +di rec +W h +a il +C oun +o c +c ri +m i +sup port</w> +th o +re st +ta ke</w> +ci l</w> +in clu +pro b +th ing</w> +M ember</w> +ati ve</w> +h e</w> +le g +ma ke</w> +ta in +1 9 +a st +h el +A n +C on +po in +Coun cil</w> +l ar +e ff +ic i +o ver</w> +po lic +e en</w> +m p +is e</w> +c ol +c our +on g</w> +en ts</w> +g et</w> +h is</w> +il e</w> +sp eci +per i +tw een</w> +ic es</w> +j o +re l +H o +be tween</w> +n um +i g +is h</w> +j ust</w> +po si +prob le +pro vi +the m</w> +ti es</w> +a in</w> +el l +n ing</w> +resp on +u r</w> +C h +) .</w> +or s</w> +p s</w> +tw o</w> +y ear +f un +m u +pro pos +pu b +po li +I f</w> +al ity</w> +com e</w> +inform ation</w> +su l +sti tu +a use</w> +m ade</w> +n ec +al i +n ed</w> +c ur +d ec +diff e +ce p +is su +year s</w> +ac h +ap pro +em b +ow ever</w> +ri ght</w> +c an +en d</w> +it u +m at +re ad +te c +v ing</w> +f ic +t re +us ed</w> +.. .</w> +ho tel</w> +wh ere</w> +c ar +gh ts</w> +0 0</w> +ac k</w> +h ere</w> +ot h</w> +r ic +te ch +de b +e l</w> +o f +o m</w> +y e +B u +c or +pos sible</w> +th rou +j ec +ou gh</w> +pro c +stru c +e i +e t +tic ul +un ity</w> +con cer +li c</w> +or der</w> +ver n +l in +polic y</w> +sh i +u m</w> +per s +ti c</w> +a k +th en</w> +ac t</w> +k et</w> +ti on +are a</w> +f t</w> +m ay</w> +qu es +s ur +syste m</w> +us in +a tes</w> +g ood</w> +ta in</w> +as ed</w> +is h +p p +s itu +c lo +par t</w> +st an +ac ti +c y</w> +e d +in cre +m il +par ticul +e ve</w> +re fore</w> +b r +in ce</w> +om s</w> +or y</w> +poli tical</w> +ac k +ag es</w> +develop ment</w> +ha d</w> +F or</w> +ar ds</w> +bec ause</w> +e y</w> +num ber</w> +s c +tho se</w> +ur ing</w> +Com m +hi gh +no w +s ure</w> +wor ld</w> +R e +af ter</w> +for m</w> +s m +un der +Th ere</w> +be ing</w> +comp le +l es</w> +pl ace</w> +r y</w> +cont in +is t +j ect</w> +si ve</w> +A l +A s</w> +ai d</w> +bu il +ho w</w> +lo w</w> +ma in +om e +ot e</w> +si on +Y ou</w> +b eli +c entr +ev en</w> +in v +is t</w> +v en</w> +vie w</w> +ye ar</w> +f o +f ol +res s</w> +& # +b usin +i d +ro w +c ould</w> +ch ar +diffe rent</w> +econom ic</w> +is m</w> +ou s +se e</w> +tri bu +countr y</w> +d re +e f +en g +k ed</w> +u d +w ant</w> +an ts</w> +p lo +u res</w> +with in</w> +y ing</w> +in k</w> +A r +c ed</w> +d es</w> +do es</w> +e ment</w> +fin d</w> +or d +S t +a ir +ag ree +b oth</w> +en s</w> +ou se</w> +ou t +pl ic +su m +ar ly</w> +f l +k now</w> +lo o +pro tec +pub lic</w> +s pe +v al +a ter</w> +i de +d at +en d +mu ch</w> +s ed</w> +throu gh</w> +ati c</w> +fin an +g an +mar ket</w> +se t</w> +v ari +Ho tel</w> +c ul +f am +f ur +i l</w> +produc ts</w> +sel f</w> +sti ll</w> +an k</w> +cle ar</w> +r ap +st ar +ve l</w> +go vern +men d +t or</w> +un der</w> +i a</w> +lo g +re qui +re sul +s a +te x +ti onal</w> +to day</w> +v o +f ree</w> +gre at</w> +l ine</w> +olu tion</w> +qu ality</w> +read y</w> +ri ghts</w> +s ome +st ra +w s</w> +ag ain +c ap +ci p +de sig +l and</w> +pr in +ro om</w> +again st</w> +c ase</w> +con 
di +ne e +si ons</w> +te n</w> +Th at</w> +av ail +beli eve</w> +e ver +ex peri +l an +oper ation</w> +cont ro +g ener +i z +ta k +te e</w> +th er +the refore</w> +Bu t</w> +G er +ab ility</w> +be fore</w> +gr ou +l ed</w> +off er</w> +s ame</w> +th or +ch il +chan ge</w> +ful ly</w> +th ink</w> +cu st +gr am +sid e</w> +O ur</w> +busin ess</w> +c c +l ast</w> +m an</w> +me di +off ers</w> +serv ices</w> +vi ron +hel p</w> +n er</w> +o w</w> +par t +proc ess</w> +s en +t able</w> +ro s +so cial</w> +um an</w> +w ays</w> +b er +b le</w> +d ic +k now +lo b +on g +ro oms</w> +t al +t ro +l ess</w> +ma in</w> +op e</w> +deb ate</w> +du st +el e +el ec +in d</w> +inter national</w> +ment al</w> +nec ess +po w +poin t</w> +s ite</w> +so ci +P ro +bu d +c us +can not</w> +d ra +eff ec +im pro +tur al</w> +en viron +m ul +n at +s ince</w> +shi p</w> +b re +cour se</w> +fe ren +i ons</w> +me as +n s</w> +reg ul +s ay</w> +s k +sm all</w> +ac cep +li f +m es</w> +me ans</w> +or gan +posi tion</w> +t le</w> +to o</w> +ve ly</w> +with out</w> +Commis sion +c all +c ell +di vi +dis cus +p ur +pro ce +th ough</w> +A f +M ar +ac coun +al ready</w> +ap p +pl ay +r s</w> +&# 9 +1 0</w> +A u +In ter +d uring</w> +dat a</w> +en cy</w> +fu ture</w> +in stitu +oo k</w> +r is +r it +E n +H owever</w> +ad v +cust om +d es +ex p +finan cial</w> +low ing</w> +situ ation</w> +S u +a ff +a ir</w> +avail able</w> +g es</w> +h uman</w> +it al</w> +r e</w> +st and +wh e +2 0 +e p +em plo +ent ly</w> +in dust +issu e</w> +o p</w> diff --git a/data/sample.sp.model b/data/sample.sp.model Binary files differnew file mode 100644 index 00000000..2e44c6ea --- /dev/null +++ b/data/sample.sp.model diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py new file mode 100644 index 00000000..f6fc3231 --- /dev/null +++ b/onmt/tests/test_transform.py @@ -0,0 +1,465 @@ +"""Here come the tests for implemented transform.""" +import unittest + +import copy +import yaml +import math +from argparse import Namespace +from onmt.transforms import get_transforms_cls, get_specials, make_transforms +from onmt.transforms.bart import BARTNoising + + +class TestTransform(unittest.TestCase): + def test_transform_register(self): + builtin_transform = [ + "filtertoolong", + "prefix", + "sentencepiece", + "bpe", + "onmt_tokenize", + "bart", + "switchout", + "tokendrop", + "tokenmask", + ] + get_transforms_cls(builtin_transform) + + def test_vocab_required_transform(self): + transforms_cls = get_transforms_cls(["bart", "switchout"]) + opt = Namespace(seed=-1, switchout_temperature=1.0) + # transforms that require vocab will not create if not provide vocab + transforms = make_transforms(opt, transforms_cls, fields=None) + self.assertEqual(len(transforms), 0) + with self.assertRaises(ValueError): + transforms_cls["switchout"](opt).warm_up(vocabs=None) + transforms_cls["bart"](opt).warm_up(vocabs=None) + + def test_transform_specials(self): + transforms_cls = get_transforms_cls(["prefix"]) + corpora = yaml.safe_load(""" + trainset: + path_src: data/src-train.txt + path_tgt: data/tgt-train.txt + transforms: ["prefix"] + weight: 1 + src_prefix: "⦅_pf_src⦆" + tgt_prefix: "⦅_pf_tgt⦆" + """) + opt = Namespace(data=corpora) + specials = get_specials(opt, transforms_cls) + specials_expected = {"src": {"⦅_pf_src⦆"}, "tgt": {"⦅_pf_tgt⦆"}} + self.assertEqual(specials, specials_expected) + + +class TestMiscTransform(unittest.TestCase): + def test_prefix(self): + prefix_cls = get_transforms_cls(["prefix"])["prefix"] + corpora = yaml.safe_load(""" + 
trainset: + path_src: data/src-train.txt + path_tgt: data/tgt-train.txt + transforms: [prefix] + weight: 1 + src_prefix: "⦅_pf_src⦆" + tgt_prefix: "⦅_pf_tgt⦆" + """) + opt = Namespace(data=corpora, seed=-1) + prefix_transform = prefix_cls(opt) + prefix_transform.warm_up() + self.assertIn("trainset", prefix_transform.prefix_dict) + + ex_in = { + "src": ["Hello", "world", "."], + "tgt": ["Bonjour", "le", "monde", "."], + } + with self.assertRaises(ValueError): + prefix_transform.apply(ex_in) + prefix_transform.apply(ex_in, corpus_name="validset") + ex_out = prefix_transform.apply(ex_in, corpus_name="trainset") + self.assertEqual(ex_out["src"][0], "⦅_pf_src⦆") + self.assertEqual(ex_out["tgt"][0], "⦅_pf_tgt⦆") + + def test_filter_too_long(self): + filter_cls = get_transforms_cls(["filtertoolong"])["filtertoolong"] + opt = Namespace(src_seq_length=100, tgt_seq_length=100) + filter_transform = filter_cls(opt) + # filter_transform.warm_up() + ex_in = { + "src": ["Hello", "world", "."], + "tgt": ["Bonjour", "le", "monde", "."], + } + ex_out = filter_transform.apply(ex_in) + self.assertIs(ex_out, ex_in) + filter_transform.tgt_seq_length = 2 + ex_out = filter_transform.apply(ex_in) + self.assertIsNone(ex_out) + + +class TestSubwordTransform(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.base_opts = { + "seed": 3431, + "share_vocab": False, + "src_subword_model": "data/sample.bpe", + "tgt_subword_model": "data/sample.bpe", + "src_subword_nbest": 1, + "tgt_subword_nbest": 1, + "src_subword_alpha": 0.0, + "tgt_subword_alpha": 0.0, + "src_subword_vocab": "", + "tgt_subword_vocab": "", + "src_vocab_threshold": 0, + "tgt_vocab_threshold": 0, + } + + def test_bpe(self): + bpe_cls = get_transforms_cls(["bpe"])["bpe"] + opt = Namespace(**self.base_opts) + bpe_cls._validate_options(opt) + bpe_transform = bpe_cls(opt) + bpe_transform.warm_up() + ex = { + "src": ["Hello", "world", "."], + "tgt": ["Bonjour", "le", "monde", "."], + } + bpe_transform.apply(ex, is_train=True) + ex_gold = { + "src": ["H@@", "ell@@", "o", "world", "."], + "tgt": ["B@@", "on@@", "j@@", "our", "le", "mon@@", "de", "."], + } + self.assertEqual(ex, ex_gold) + # test BPE-dropout: + bpe_transform.dropout["src"] = 1.0 + tokens = ["Another", "world", "."] + gold_bpe = ["A@@", "no@@", "ther", "world", "."] + gold_dropout = [ + "A@@", "n@@", "o@@", "t@@", "h@@", "e@@", "r", + "w@@", "o@@", "r@@", "l@@", "d", ".", + ] + # 1. disable bpe dropout for not training example + after_bpe = bpe_transform._tokenize(tokens, is_train=False) + self.assertEqual(after_bpe, gold_bpe) + # 2. enable bpe dropout for training example + after_bpe = bpe_transform._tokenize(tokens, is_train=True) + self.assertEqual(after_bpe, gold_dropout) + # 3. 
(NOTE) disable dropout won't take effect if already seen + # this is caused by the cache mechanism in bpe: + # return cached subword if the original token is seen when no dropout + after_bpe2 = bpe_transform._tokenize(tokens, is_train=False) + self.assertEqual(after_bpe2, gold_dropout) + + def test_sentencepiece(self): + sp_cls = get_transforms_cls(["sentencepiece"])["sentencepiece"] + base_opt = copy.copy(self.base_opts) + base_opt["src_subword_model"] = "data/sample.sp.model" + base_opt["tgt_subword_model"] = "data/sample.sp.model" + opt = Namespace(**base_opt) + sp_cls._validate_options(opt) + sp_transform = sp_cls(opt) + sp_transform.warm_up() + ex = { + "src": ["Hello", "world", "."], + "tgt": ["Bonjour", "le", "monde", "."], + } + sp_transform.apply(ex, is_train=True) + ex_gold = { + "src": ["▁H", "el", "lo", "▁world", "▁."], + "tgt": ["▁B", "on", "j", "o", "ur", "▁le", "▁m", "on", "de", "▁."], + } + self.assertEqual(ex, ex_gold) + # test SP regularization: + sp_transform.src_subword_nbest = 4 + tokens = ["Another", "world", "."] + gold_sp = ["▁An", "other", "▁world", "▁."] + # 1. enable regularization for training example + after_sp = sp_transform._tokenize(tokens, is_train=True) + self.assertEqual(after_sp, ["▁An", "o", "ther", "▁world", "▁."]) + # 2. disable regularization for not training example + after_sp = sp_transform._tokenize(tokens, is_train=False) + self.assertEqual(after_sp, gold_sp) + + def test_pyonmttok_bpe(self): + onmttok_cls = get_transforms_cls(["onmt_tokenize"])["onmt_tokenize"] + base_opt = copy.copy(self.base_opts) + base_opt["src_subword_type"] = "bpe" + base_opt["tgt_subword_type"] = "bpe" + onmt_args = "{'mode': 'space', 'joiner_annotate': True}" + base_opt["src_onmttok_kwargs"] = onmt_args + base_opt["tgt_onmttok_kwargs"] = onmt_args + opt = Namespace(**base_opt) + onmttok_cls._validate_options(opt) + onmttok_transform = onmttok_cls(opt) + onmttok_transform.warm_up() + ex = { + "src": ["Hello", "world", "."], + "tgt": ["Bonjour", "le", "monde", "."], + } + onmttok_transform.apply(ex, is_train=True) + ex_gold = { + "src": ["H■", "ell■", "o", "world", "."], + "tgt": ["B■", "on■", "j■", "our", "le", "mon■", "de", "."], + } + self.assertEqual(ex, ex_gold) + + def test_pyonmttok_sp(self): + onmttok_cls = get_transforms_cls(["onmt_tokenize"])["onmt_tokenize"] + base_opt = copy.copy(self.base_opts) + base_opt["src_subword_type"] = "sentencepiece" + base_opt["tgt_subword_type"] = "sentencepiece" + base_opt["src_subword_model"] = "data/sample.sp.model" + base_opt["tgt_subword_model"] = "data/sample.sp.model" + onmt_args = "{'mode': 'none', 'spacer_annotate': True}" + base_opt["src_onmttok_kwargs"] = onmt_args + base_opt["tgt_onmttok_kwargs"] = onmt_args + opt = Namespace(**base_opt) + onmttok_cls._validate_options(opt) + onmttok_transform = onmttok_cls(opt) + onmttok_transform.warm_up() + ex = { + "src": ["Hello", "world", "."], + "tgt": ["Bonjour", "le", "monde", "."], + } + onmttok_transform.apply(ex, is_train=True) + ex_gold = { + "src": ["▁H", "el", "lo", "▁world", "▁."], + "tgt": ["▁B", "on", "j", "o", "ur", "▁le", "▁m", "on", "de", "▁."], + } + self.assertEqual(ex, ex_gold) + + +class TestSamplingTransform(unittest.TestCase): + def test_tokendrop(self): + tokendrop_cls = get_transforms_cls(["tokendrop"])["tokendrop"] + opt = Namespace(seed=3434, tokendrop_temperature=0.1) + tokendrop_transform = tokendrop_cls(opt) + tokendrop_transform.warm_up() + ex = { + "src": ["Hello", ",", "world", "."], + "tgt": ["Bonjour", "le", "monde", "."], + } + # Not apply token 
drop for not training example + ex_after = tokendrop_transform.apply(copy.deepcopy(ex), is_train=False) + self.assertEqual(ex_after, ex) + # apply token drop for training example + ex_after = tokendrop_transform.apply(copy.deepcopy(ex), is_train=True) + self.assertNotEqual(ex_after, ex) + + def test_tokenmask(self): + tokenmask_cls = get_transforms_cls(["tokenmask"])["tokenmask"] + opt = Namespace(seed=3434, tokenmask_temperature=0.1) + tokenmask_transform = tokenmask_cls(opt) + tokenmask_transform.warm_up() + ex = { + "src": ["Hello", ",", "world", "."], + "tgt": ["Bonjour", "le", "monde", "."], + } + # Not apply token mask for not training example + ex_after = tokenmask_transform.apply(copy.deepcopy(ex), is_train=False) + self.assertEqual(ex_after, ex) + # apply token mask for training example + ex_after = tokenmask_transform.apply(copy.deepcopy(ex), is_train=True) + self.assertNotEqual(ex_after, ex) + + def test_switchout(self): + switchout_cls = get_transforms_cls(["switchout"])["switchout"] + opt = Namespace(seed=3434, switchout_temperature=0.1) + switchout_transform = switchout_cls(opt) + with self.assertRaises(ValueError): + # require vocabs to warm_up + switchout_transform.warm_up(vocabs=None) + vocabs = { + "src": Namespace(itos=["A", "Fake", "vocab"]), + "tgt": Namespace(itos=["A", "Fake", "vocab"]), + } + switchout_transform.warm_up(vocabs=vocabs) + ex = { + "src": ["Hello", ",", "world", "."], + "tgt": ["Bonjour", "le", "monde", "."], + } + # Not apply token mask for not training example + ex_after = switchout_transform.apply(copy.deepcopy(ex), is_train=False) + self.assertEqual(ex_after, ex) + # apply token mask for training example + ex_after = switchout_transform.apply(copy.deepcopy(ex), is_train=True) + self.assertNotEqual(ex_after, ex) + + +class TestBARTNoising(unittest.TestCase): + def setUp(self): + BARTNoising.set_random_seed(1234) + self.MASK_TOK = "[MASK]" + self.FAKE_VOCAB = "[TESTING]" + + def test_sentence_permute(self): + sent1 = ["Hello", "world", "."] + sent2 = ["Sentence", "1", "!"] + sent3 = ["Sentence", "2", "!"] + sent4 = ["Sentence", "3", "!"] + + bart_noise = BARTNoising( + vocab=[self.FAKE_VOCAB], + permute_sent_ratio=0.5, + replace_length=0, # not raise Error + # Defalt: full_stop_token=[".", "?", "!"] + ) + tokens = sent1 + sent2 + sent3 + sent4 + ends = bart_noise._get_sentence_borders(tokens).tolist() + self.assertEqual(ends, [3, 6, 9, 12]) + tokens_perm = bart_noise.apply(tokens) + expected_tokens = sent2 + sent1 + sent3 + sent4 + self.assertEqual(expected_tokens, tokens_perm) + + def test_rotate(self): + bart_noise = BARTNoising( + vocab=[self.FAKE_VOCAB], + rotate_ratio=1.0, + replace_length=0, # not raise Error + ) + tokens = ["This", "looks", "really", "good", "!"] + rotated = bart_noise.apply(tokens) + self.assertNotEqual(tokens, rotated) + not_rotate = bart_noise.rolling_noise(tokens, p=0.0) + self.assertEqual(tokens, not_rotate) + + def test_token_insert(self): + bart_noise = BARTNoising( + vocab=[self.FAKE_VOCAB], + mask_tok=self.MASK_TOK, + insert_ratio=0.5, + random_ratio=0.3, + replace_length=0, # not raise Error + # Defalt: full_stop_token=[".", "?", "!"] + ) + tokens = ["This", "looks", "really", "good", "!"] + inserted = bart_noise.apply(tokens) + n_insert = math.ceil(len(tokens) * bart_noise.insert_ratio) + inserted_len = n_insert + len(tokens) + self.assertEqual(len(inserted), inserted_len) + # random_ratio of inserted tokens are chosen in vocab + n_random = math.ceil(n_insert * bart_noise.random_ratio) + self.assertEqual( + sum(1 if 
tok == self.FAKE_VOCAB else 0 for tok in inserted), + n_random, + ) + # others are MASK_TOK + self.assertEqual( + sum(1 if tok == self.MASK_TOK else 0 for tok in inserted), + n_insert - n_random, + ) + + def test_token_mask(self): + """Mask will be done on token level. + + Condition: + * `mask_length` == subword; + * or not specify subword marker (joiner/spacer) by `is_joiner`. + """ + bart_noise = BARTNoising( + vocab=[self.FAKE_VOCAB], + mask_tok=self.MASK_TOK, + mask_ratio=0.5, + mask_length="subword", + replace_length=0, # 0 to drop them, 1 to replace them with MASK + # insert_ratio=0.0, + # random_ratio=0.0, + # Defalt: full_stop_token=[".", "?", "!"] + ) + tokens = ["H■", "ell■", "o", "world", "."] + # all token are considered as an individual word + self.assertTrue(all(bart_noise._is_word_start(tokens))) + n_tokens = len(tokens) + + # 1. tokens are dropped when replace_length is 0 + masked = bart_noise.apply(tokens) + n_masked = math.ceil(n_tokens * bart_noise.mask_ratio) + # print(f"token delete: {masked} / {tokens}") + self.assertEqual(len(masked), n_tokens - n_masked) + + # 2. tokens are replaced by MASK when replace_length is 1 + bart_noise.replace_length = 1 + masked = bart_noise.apply(tokens) + n_masked = math.ceil(n_tokens * bart_noise.mask_ratio) + # print(f"token mask: {masked} / {tokens}") + self.assertEqual(len(masked), n_tokens) + self.assertEqual( + sum([1 if tok == self.MASK_TOK else 0 for tok in masked]), n_masked + ) + + def test_whole_word_mask(self): + """Mask will be done on whole word that may across multiply token. + + Condition: + * `mask_length` == word; + * specify subword marker in order to find word boundary. + """ + bart_noise = BARTNoising( + vocab=[self.FAKE_VOCAB], + mask_tok=self.MASK_TOK, + mask_ratio=0.5, + mask_length="word", + is_joiner=True, + replace_length=0, # 0 to drop them, 1 to replace them with MASK + # insert_ratio=0.0, + # random_ratio=0.0, + # Defalt: full_stop_token=[".", "?", "!"] + ) + tokens = ["H■", "ell■", "o", "wor■", "ld", "."] + # start token of word are identified using subword marker + token_starts = [True, False, False, True, False, True] + self.assertEqual(bart_noise._is_word_start(tokens), token_starts) + + # 1. replace_length 0: "words" are dropped + masked = bart_noise.apply(copy.copy(tokens)) + n_words = sum(token_starts) + n_masked = math.ceil(n_words * bart_noise.mask_ratio) + # print(f"word delete: {masked} / {tokens}") + # self.assertEqual(len(masked), n_words - n_masked) + + # 2. replace_length 1: "words" are replaced with a single MASK + bart_noise.replace_length = 1 + masked = bart_noise.apply(copy.copy(tokens)) + # print(f"whole word single mask: {masked} / {tokens}") + # len(masked) depend on number of tokens in select word + n_words = sum(token_starts) + n_masked = math.ceil(n_words * bart_noise.mask_ratio) + self.assertEqual( + sum(1 if tok == self.MASK_TOK else 0 for tok in masked), n_masked + ) + + # 3. 
replace_length -1: all tokens in "words" are replaced with MASK + bart_noise.replace_length = -1 + masked = bart_noise.apply(copy.copy(tokens)) + # print(f"whole word multi mask: {masked} / {tokens}") + self.assertEqual(len(masked), len(tokens)) # length won't change + n_words = sum(token_starts) + n_masked = math.ceil(n_words * bart_noise.mask_ratio) + # number of mask_tok depend on number of tokens in selected word + # number of MASK_TOK can be greater than n_masked + self.assertTrue( + sum(1 if tok == self.MASK_TOK else 0 for tok in masked) > n_masked + ) + + def test_span_infilling(self): + bart_noise = BARTNoising( + vocab=[self.FAKE_VOCAB], + mask_tok=self.MASK_TOK, + mask_ratio=0.5, + mask_length="span-poisson", + poisson_lambda=3.0, + is_joiner=True, + replace_length=1, + # insert_ratio=0.5, + # random_ratio=0.3, + # Defalt: full_stop_token=[".", "?", "!"] + ) + self.assertIsNotNone(bart_noise.mask_span_distribution) + tokens = ["H■", "ell■", "o", "world", ".", "An■", "other", "!"] + # start token of word are identified using subword marker + token_starts = [True, False, False, True, True, True, False, True] + self.assertEqual(bart_noise._is_word_start(tokens), token_starts) + bart_noise.apply(copy.copy(tokens)) + # n_words = sum(token_starts) + # n_masked = math.ceil(n_words * bart_noise.mask_ratio) + # print(f"Text Span Infilling: {infillied} / {tokens}") + # print(n_words, n_masked) diff --git a/onmt/transforms/bart.py b/onmt/transforms/bart.py index ed116dcf..212c20ef 100644 --- a/onmt/transforms/bart.py +++ b/onmt/transforms/bart.py @@ -88,6 +88,12 @@ class BARTNoising(object): self.mask_length = mask_length self.poisson_lambda = poisson_lambda + @staticmethod + def set_random_seed(seed): + """Call this before use to ensure reproducibility.""" + np.random.seed(seed) + torch.manual_seed(seed) + def _make_poisson(self, poisson_lambda): lambda_to_the_k = 1 e_to_the_minus_lambda = math.exp(-poisson_lambda) @@ -102,20 +108,25 @@ class BARTNoising(object): ps = torch.FloatTensor(ps) return torch.distributions.Categorical(ps) - def _is_full_stop(self, token): - return True if token in self.full_stop_token else False - - def permute_sentences(self, tokens, p=1.0): - if len(tokens) == 1: - return tokens - full_stops = np.array([self._is_full_stop(token) for token in tokens]) + def _get_sentence_borders(self, tokens): + """Return lengths of each sentence in the token sequence.""" + full_stops = np.array( + [ + True if token in self.full_stop_token else False + for token in tokens + ] + ) # Pretend it ends with a full stop so last span is a sentence full_stops[-1] = True - # Tokens that are full stops, where the previous token is not - sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero()[0] + 2 + sentence_lens = (full_stops[1:] * ~full_stops[:-1]).nonzero()[0] + 2 + return sentence_lens - n_sentences = sentence_ends.size + def permute_sentences(self, tokens, p=1.0): + if len(tokens) == 1: + return tokens + sentence_lens = self._get_sentence_borders(tokens) + n_sentences = sentence_lens.size if n_sentences == 1: return tokens @@ -129,8 +140,8 @@ class BARTNoising(object): result = [tok for tok in tokens] index = 0 for i in ordering: - sentence = tokens[(sentence_ends[i - 1] if i > 0 else 0): - sentence_ends[i]] + sentence = tokens[(sentence_lens[i - 1] if i > 0 else 0): + sentence_lens[i]] result[index:index + len(sentence)] = sentence index += len(sentence) assert len(result) == len(tokens), "Error when permute sentences." 
@@ -331,8 +342,7 @@ class BARTNoiseTransform(Transform): def _set_seed(self, seed): """set seed to ensure reproducibility.""" - np.random.seed(seed) - torch.manual_seed(seed) + BARTNoising.set_random_seed(seed) @classmethod def add_options(cls, parser): diff --git a/onmt/transforms/sampling.py b/onmt/transforms/sampling.py index c4a48e7b..b816ebce 100644 --- a/onmt/transforms/sampling.py +++ b/onmt/transforms/sampling.py @@ -1,7 +1,6 @@ """Transforms relate to hamming distance sampling.""" import random import numpy as np -from onmt.utils.logging import logger from onmt.constants import DefaultTokens from onmt.transforms import register_transform from .transform import Transform @@ -79,25 +78,19 @@ class SwitchOutTransform(HammingDistanceSamplingTransform): # 2. sample positions to corrput chosen_indices = self._sample_position(tokens, distance=n_chosen) # 3. sample corrupted values - out = [] - for (i, tok) in enumerate(tokens): - if i in chosen_indices: - tok = self._sample_replace(vocab, reject=tok) - out.append(tok) - else: - out.append(tok) + for i in chosen_indices: + tokens[i] = self._sample_replace(vocab, reject=tokens[i]) if stats is not None: stats.switchout(n_switchout=n_chosen, n_total=len(tokens)) - return out + return tokens def apply(self, example, is_train=False, stats=None, **kwargs): """Apply switchout to both src and tgt side tokens.""" if is_train: - src = self._switchout( + example['src'] = self._switchout( example['src'], self.vocabs['src'].itos, stats) - tgt = self._switchout( + example['tgt'] = self._switchout( example['tgt'], self.vocabs['tgt'].itos, stats) - example['src'], example['tgt'] = src, tgt return example def _repr_args(self): @@ -124,6 +117,7 @@ class TokenDropTransform(HammingDistanceSamplingTransform): self.temperature = self.opts.tokendrop_temperature def _token_drop(self, tokens, stats=None): + n_items = len(tokens) # 1. sample number of tokens to corrupt n_chosen = self._sample_distance(tokens, self.temperature) # 2. sample positions to corrput @@ -132,20 +126,19 @@ class TokenDropTransform(HammingDistanceSamplingTransform): out = [tok for (i, tok) in enumerate(tokens) if i not in chosen_indices] if stats is not None: - stats.token_drop(n_dropped=n_chosen, n_total=len(tokens)) + stats.token_drop(n_dropped=n_chosen, n_total=n_items) return out def apply(self, example, is_train=False, stats=None, **kwargs): """Apply token drop to both src and tgt side tokens.""" if is_train: - src = self._token_drop(example['src'], stats) - tgt = self._token_drop(example['tgt'], stats) - example['src'], example['tgt'] = src, tgt + example['src'] = self._token_drop(example['src'], stats) + example['tgt'] = self._token_drop(example['tgt'], stats) return example def _repr_args(self): """Return str represent key arguments for class.""" - return '{}={}'.format('worddrop_temperature', self.temperature) + return '{}={}'.format('tokendrop_temperature', self.temperature) @register_transform(name='tokenmask') @@ -179,20 +172,16 @@ class TokenMaskTransform(HammingDistanceSamplingTransform): # 2. sample positions to corrput chosen_indices = self._sample_position(tokens, distance=n_chosen) # 3. 
mask word on chosen position - out = [] - for (i, tok) in enumerate(tokens): - tok = self.MASK_TOK if i in chosen_indices else tok - out.append(tok) + for i in chosen_indices: + tokens[i] = self.MASK_TOK if stats is not None: stats.token_mask(n_masked=n_chosen, n_total=len(tokens)) - return out + return tokens def apply(self, example, is_train=False, stats=None, **kwargs): """Apply word drop to both src and tgt side tokens.""" if is_train: - src = self._token_mask(example['src'], stats) - tgt = self._token_mask(example['tgt'], stats) - example['src'], example['tgt'] = src, tgt + example['src'] = self._token_mask(example['src'], stats) return example def _repr_args(self): diff --git a/onmt/transforms/tokenize.py b/onmt/transforms/tokenize.py index 4f343e47..b01b10fb 100644 --- a/onmt/transforms/tokenize.py +++ b/onmt/transforms/tokenize.py @@ -153,14 +153,14 @@ class SentencePieceTransform(TokenizerTransform): sentence = ' '.join(tokens) nbest_size = self.tgt_subword_nbest if side == 'tgt' else \ self.src_subword_nbest - alpha = self.tgt_subword_alpha if side == 'tgt' else \ - self.src_subword_alpha if is_train is False or nbest_size in [0, 1]: # derterministic subwording segmented = sp_model.encode(sentence, out_type=str) else: # subword sampling when nbest_size > 1 or -1 # alpha should be 0.0 < alpha < 1.0 + alpha = self.tgt_subword_alpha if side == 'tgt' else \ + self.src_subword_alpha segmented = sp_model.encode( sentence, out_type=str, enable_sampling=True, alpha=alpha, nbest_size=nbest_size) @@ -208,26 +208,25 @@ class BPETransform(TokenizerTransform): """Load subword models.""" super().warm_up(None) from subword_nmt.apply_bpe import BPE, read_vocabulary - import codecs - src_codes = codecs.open(self.src_subword_model, encoding='utf-8') + # Load vocabulary file if provided and set threshold src_vocabulary, tgt_vocabulary = None, None if self.src_subword_vocab != "" and self.src_vocab_threshold > 0: - src_vocabulary = read_vocabulary( - codecs.open(self.src_subword_vocab, encoding='utf-8'), - self.src_vocab_threshold) + with open(self.src_subword_vocab, encoding='utf-8') as _sv: + src_vocabulary = read_vocabulary(_sv, self.src_vocab_threshold) if self.tgt_subword_vocab != "" and self.tgt_vocab_threshold > 0: - tgt_vocabulary = read_vocabulary( - codecs.open(self.tgt_subword_vocab, encoding='utf-8'), - self.tgt_vocab_threshold) - load_src_model = BPE(codes=src_codes, vocab=src_vocabulary) + with open(self.tgt_subword_vocab, encoding='utf-8') as _tv: + tgt_vocabulary = read_vocabulary(_tv, self.tgt_vocab_threshold) + # Load Subword Model + with open(self.src_subword_model, encoding='utf-8') as src_codes: + load_src_model = BPE(codes=src_codes, vocab=src_vocabulary) if self.share_vocab and (src_vocabulary == tgt_vocabulary): self.load_models = { 'src': load_src_model, 'tgt': load_src_model } else: - tgt_codes = codecs.open(self.tgt_subword_model, encoding='utf-8') - load_tgt_model = BPE(codes=tgt_codes, vocab=tgt_vocabulary) + with open(self.tgt_subword_model, encoding='utf-8') as tgt_codes: + load_tgt_model = BPE(codes=tgt_codes, vocab=tgt_vocabulary) self.load_models = { 'src': load_src_model, 'tgt': load_tgt_model @@ -236,7 +235,7 @@ class BPETransform(TokenizerTransform): def _tokenize(self, tokens, side='src', is_train=False): """Do bpe subword tokenize.""" bpe_model = self.load_models[side] - dropout = self.dropout[side] if is_train else 0 + dropout = self.dropout[side] if is_train else 0.0 segmented = bpe_model.segment_tokens(tokens, dropout=dropout) return segmented |
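
The refactor in onmt/transforms/sampling.py replaces the rebuild-the-whole-list loops with in-place writes at the sampled positions. Below is a self-contained sketch of that pattern, illustration only: it uses random.Random and a fixed mask string in place of the transform's own _sample_distance/_sample_position helpers and MASK_TOK constant.

    import random

    MASK_TOK = "<mask>"  # stand-in for the transform's MASK_TOK constant

    def token_mask(tokens, temperature=0.1, seed=3434):
        """Mask a sampled number of positions in place (illustrative sketch)."""
        rng = random.Random(seed)
        # 1. decide how many tokens to corrupt (stand-in for _sample_distance)
        n_chosen = max(1, round(len(tokens) * temperature))
        # 2. sample which positions to corrupt (stand-in for _sample_position)
        chosen_indices = rng.sample(range(len(tokens)), n_chosen)
        # 3. overwrite the chosen positions directly instead of building a new list
        for i in chosen_indices:
            tokens[i] = MASK_TOK
        return tokens

    print(token_mask(["Hello", ",", "world", "."]))
    # one of the four positions is replaced by "<mask>"
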