github.com/OpenNMT/OpenNMT-py.git
author    Linxiao ZENG <linxiao.zeng@gmail.com>    2021-04-26 11:36:59 +0300
committer GitHub <noreply@github.com>              2021-04-26 11:36:59 +0300
commit    1e8173cbfaba6f1248b6df53f5fd00892ce2e69c (patch)
tree      d6debcf84138e126d22a7cdebd5d8d1ad810522c
parent    5c2eadada650b15f1bf46ec86f6e7008962efedd (diff)
Add some transforms unittest (#2049)
-rw-r--r--  data/sample.bpe               1001
-rw-r--r--  data/sample.sp.model          bin 0 -> 252917 bytes
-rw-r--r--  onmt/tests/test_transform.py  465
-rw-r--r--  onmt/transforms/bart.py       38
-rw-r--r--  onmt/transforms/sampling.py   39
-rw-r--r--  onmt/transforms/tokenize.py   27
6 files changed, 1517 insertions, 53 deletions
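
This commit adds unit tests for the data transforms, plus the sample BPE codes and SentencePiece model they rely on. As a minimal sketch (assuming the repository root is the working directory and the optional tokenizer dependencies are installed), the new tests can be run with the standard library's unittest loader:

import unittest

# Discover and run only the new transform tests.
suite = unittest.defaultTestLoader.discover("onmt/tests", pattern="test_transform.py")
unittest.TextTestRunner(verbosity=2).run(suite)
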
diff --git a/data/sample.bpe b/data/sample.bpe
new file mode 100644
index 00000000..59e66f35
--- /dev/null
+++ b/data/sample.bpe
@@ -0,0 +1,1001 @@
+#version: 0.2
+t h
+i n
+th e</w>
+a n
+r e
+t i
+e r
+e n
+o n
+a r
+o u
+o f</w>
+o n</w>
+o r
+an d</w>
+t o</w>
+t e
+r o
+in g</w>
+i s</w>
+i n</w>
+a l
+i t
+e s</w>
+i s
+e d</w>
+e r</w>
+a t</w>
+o r</w>
+a l</w>
+o m
+i c
+s t
+a ti
+s i
+a c
+a n</w>
+l y</w>
+e c
+a s</w>
+i l
+p o
+a t
+e s
+T h
+e l
+v e</w>
+m en
+u r
+p ro
+a s
+th at</w>
+f or</w>
+l i
+t s</w>
+l e</w>
+w h
+l e
+t r
+u s
+n o
+e n</w>
+c on
+d e
+ati on</w>
+s u
+h a
+ar e</w>
+l o
+a b
+a g
+a d
+te d</w>
+it h</w>
+q u
+u n
+Th e</w>
+o p
+c om
+w ith</w>
+b e</w>
+d i
+po s
+b e
+a m
+ti on</w>
+r i
+c h
+men t</w>
+ou r</w>
+i t</w>
+re s
+c e</w>
+l d</w>
+ou n
+p or
+e x
+p e
+on s</w>
+th is</w>
+u l
+f or
+v er
+n e
+o l
+th e
+t a
+w e</w>
+c i
+s h
+o u</w>
+o t
+a y</w>
+ha ve</w>
+i m
+er s</w>
+it y</w>
+er e</w>
+no t</w>
+f f
+v i
+g h
+a p
+b y</w>
+en t</w>
+ou ld</w>
+y ou</w>
+d u
+e m
+al l</w>
+il l</w>
+b u
+p l
+u ro
+at e</w>
+t er
+m o
+E uro
+w ill</w>
+& a
+&a pos
+s e
+ro m</w>
+i r
+ic h</w>
+a r</w>
+f rom</w>
+w or
+k e</w>
+te r</w>
+s t</w>
+wh ich</w>
+p ar
+m a
+s o</w>
+r a
+s e</w>
+p er
+h as</w>
+m e
+th er</w>
+ou t</w>
+o w
+d s</w>
+c o
+re n
+ab le</w>
+on e</w>
+c h</w>
+ur e</w>
+c an</w>
+si on</w>
+m an
+es s</w>
+s p
+e t</w>
+i es</w>
+e v
+ou r
+o o
+d is
+on al</w>
+C om
+an t</w>
+si d
+te s</w>
+al ly</w>
+re c
+&apos ;
+m is
+w as</w>
+ti c
+Euro pe
+g r
+Europe an</w>
+i r</w>
+por t</w>
+h o
+m e</w>
+us t</w>
+d ing</w>
+du c
+or e</w>
+a ted</w>
+c l
+W e</w>
+ag e</w>
+v e
+y our</w>
+n i
+in g
+al so</w>
+ti ons</w>
+a y
+t o
+u s</w>
+gh t</w>
+c oun
+g e</w>
+it s</w>
+t ing</w>
+b er</w>
+t u
+f in
+qu ot
+f e
+t w
+i c</w>
+r u
+s o
+&apos; s</w>
+com p
+u m
+h i
+& quot
+re d</w>
+Th is</w>
+k ing</w>
+c re
+men ts</w>
+d e</w>
+&quot ;</w>
+m ore</w>
+ati ons</w>
+s er
+be en</w>
+p res
+es s
+s te
+I n</w>
+t h</w>
+an ce</w>
+w ould</w>
+ou s</w>
+re g
+Com mis
+for m
+ver y</w>
+e ar
+M r</w>
+i ti
+d ed</w>
+sh ould</w>
+the ir</w>
+I t</w>
+us e</w>
+ar y</w>
+bu t</w>
+d er</w>
+v er</w>
+d o
+f i
+por t
+m in
+0 0
+c e
+g u
+g h</w>
+en t
+f ac
+a v
+es e</w>
+re sid
+ti ve</w>
+a u
+p r
+w i
+ec t</w>
+f u
+t y</w>
+si b
+g o
+es t</w>
+om e</w>
+ic e</w>
+o ther</w>
+o ff
+k s</w>
+t en
+no w</w>
+coun tr
+d er
+as e</w>
+l a
+P resid
+in e</w>
+p re
+s ti
+re s</w>
+v el
+con t
+p u
+P ar
+ab out</w>
+as t</w>
+w el
+d o</w>
+m ar
+d ay</w>
+Commis sion</w>
+h e
+a d</w>
+su p
+th ere</w>
+cl u
+the y</w>
+ne w</w>
+g e
+li ke</w>
+al l
+ec on
+en ce</w>
+pro duc
+s el
+com m
+k e
+ati onal</w>
+a in
+ar d</w>
+a m</w>
+in ter
+p e</w>
+m ust</w>
+an y</w>
+ec i
+ti me</w>
+Presid ent</w>
+c er
+r ou
+wor k</w>
+f ir
+a i
+m s</w>
+an s
+on ly</w>
+I n
+U ni
+de vel
+v is
+l ic
+o b
+te l</w>
+a re
+it e</w>
+v es</w>
+em ber</w>
+g i
+w e
+re e</w>
+2 00
+s y
+th ese</w>
+m on
+ad e</w>
+ci al</w>
+t or
+ac e</w>
+c u
+oun d</w>
+&apos ;</w>
+al s</w>
+si g
+ar i
+res p
+t e</w>
+im port
+for e</w>
+lo c
+devel op
+am e</w>
+u p</w>
+li a
+man y</w>
+wh o</w>
+oo d</w>
+l l</w>
+r on
+w ay</w>
+Euro pe</w>
+an d
+b ec
+m y</w>
+w ere</w>
+a f
+ren t</w>
+lo w
+wh at</w>
+e st
+ser v
+di ff
+ow n</w>
+p ri
+qu i
+su b
+ic al</w>
+ing s</w>
+ac c
+an s</w>
+at ing</w>
+c r
+pe op
+ac h</w>
+n o</w>
+su ch</w>
+t ur
+wi th
+il ity</w>
+n ational</w>
+t ure</w>
+S ta
+re e
+a te
+l is
+ne ed</w>
+ay s</w>
+mo st</w>
+en ti
+as s
+b i
+lia ment</w>
+sib le</w>
+c es</w>
+i on</w>
+peop le</w>
+ch an
+op er
+t ly</w>
+el i
+g y</w>
+iti es</w>
+s ec
+tic al</w>
+en er
+ol u
+vi e
+Par liament</w>
+am p
+i f</w>
+re port</w>
+b o
+ev er</w>
+in to</w>
+st ru
+. .
+p le
+sy ste
+th an</w>
+u e</w>
+p h
+r an
+fir st</w>
+in form
+Sta tes</w>
+countr ies</w>
+en tr
+f ul
+v ed</w>
+v en
+Uni on</w>
+il l
+E U</w>
+econ om
+p en
+wh en</w>
+c le
+fu l</w>
+s ome</w>
+tr ans
+men t
+tr i
+g re
+import ant</w>
+in te
+wel l</w>
+di rec
+W h
+a il
+C oun
+o c
+c ri
+m i
+sup port</w>
+th o
+re st
+ta ke</w>
+ci l</w>
+in clu
+pro b
+th ing</w>
+M ember</w>
+ati ve</w>
+h e</w>
+le g
+ma ke</w>
+ta in
+1 9
+a st
+h el
+A n
+C on
+po in
+Coun cil</w>
+l ar
+e ff
+ic i
+o ver</w>
+po lic
+e en</w>
+m p
+is e</w>
+c ol
+c our
+on g</w>
+en ts</w>
+g et</w>
+h is</w>
+il e</w>
+sp eci
+per i
+tw een</w>
+ic es</w>
+j o
+re l
+H o
+be tween</w>
+n um
+i g
+is h</w>
+j ust</w>
+po si
+prob le
+pro vi
+the m</w>
+ti es</w>
+a in</w>
+el l
+n ing</w>
+resp on
+u r</w>
+C h
+) .</w>
+or s</w>
+p s</w>
+tw o</w>
+y ear
+f un
+m u
+pro pos
+pu b
+po li
+I f</w>
+al ity</w>
+com e</w>
+inform ation</w>
+su l
+sti tu
+a use</w>
+m ade</w>
+n ec
+al i
+n ed</w>
+c ur
+d ec
+diff e
+ce p
+is su
+year s</w>
+ac h
+ap pro
+em b
+ow ever</w>
+ri ght</w>
+c an
+en d</w>
+it u
+m at
+re ad
+te c
+v ing</w>
+f ic
+t re
+us ed</w>
+.. .</w>
+ho tel</w>
+wh ere</w>
+c ar
+gh ts</w>
+0 0</w>
+ac k</w>
+h ere</w>
+ot h</w>
+r ic
+te ch
+de b
+e l</w>
+o f
+o m</w>
+y e
+B u
+c or
+pos sible</w>
+th rou
+j ec
+ou gh</w>
+pro c
+stru c
+e i
+e t
+tic ul
+un ity</w>
+con cer
+li c</w>
+or der</w>
+ver n
+l in
+polic y</w>
+sh i
+u m</w>
+per s
+ti c</w>
+a k
+th en</w>
+ac t</w>
+k et</w>
+ti on
+are a</w>
+f t</w>
+m ay</w>
+qu es
+s ur
+syste m</w>
+us in
+a tes</w>
+g ood</w>
+ta in</w>
+as ed</w>
+is h
+p p
+s itu
+c lo
+par t</w>
+st an
+ac ti
+c y</w>
+e d
+in cre
+m il
+par ticul
+e ve</w>
+re fore</w>
+b r
+in ce</w>
+om s</w>
+or y</w>
+poli tical</w>
+ac k
+ag es</w>
+develop ment</w>
+ha d</w>
+F or</w>
+ar ds</w>
+bec ause</w>
+e y</w>
+num ber</w>
+s c
+tho se</w>
+ur ing</w>
+Com m
+hi gh
+no w
+s ure</w>
+wor ld</w>
+R e
+af ter</w>
+for m</w>
+s m
+un der
+Th ere</w>
+be ing</w>
+comp le
+l es</w>
+pl ace</w>
+r y</w>
+cont in
+is t
+j ect</w>
+si ve</w>
+A l
+A s</w>
+ai d</w>
+bu il
+ho w</w>
+lo w</w>
+ma in
+om e
+ot e</w>
+si on
+Y ou</w>
+b eli
+c entr
+ev en</w>
+in v
+is t</w>
+v en</w>
+vie w</w>
+ye ar</w>
+f o
+f ol
+res s</w>
+& #
+b usin
+i d
+ro w
+c ould</w>
+ch ar
+diffe rent</w>
+econom ic</w>
+is m</w>
+ou s
+se e</w>
+tri bu
+countr y</w>
+d re
+e f
+en g
+k ed</w>
+u d
+w ant</w>
+an ts</w>
+p lo
+u res</w>
+with in</w>
+y ing</w>
+in k</w>
+A r
+c ed</w>
+d es</w>
+do es</w>
+e ment</w>
+fin d</w>
+or d
+S t
+a ir
+ag ree
+b oth</w>
+en s</w>
+ou se</w>
+ou t
+pl ic
+su m
+ar ly</w>
+f l
+k now</w>
+lo o
+pro tec
+pub lic</w>
+s pe
+v al
+a ter</w>
+i de
+d at
+en d
+mu ch</w>
+s ed</w>
+throu gh</w>
+ati c</w>
+fin an
+g an
+mar ket</w>
+se t</w>
+v ari
+Ho tel</w>
+c ul
+f am
+f ur
+i l</w>
+produc ts</w>
+sel f</w>
+sti ll</w>
+an k</w>
+cle ar</w>
+r ap
+st ar
+ve l</w>
+go vern
+men d
+t or</w>
+un der</w>
+i a</w>
+lo g
+re qui
+re sul
+s a
+te x
+ti onal</w>
+to day</w>
+v o
+f ree</w>
+gre at</w>
+l ine</w>
+olu tion</w>
+qu ality</w>
+read y</w>
+ri ghts</w>
+s ome
+st ra
+w s</w>
+ag ain
+c ap
+ci p
+de sig
+l and</w>
+pr in
+ro om</w>
+again st</w>
+c ase</w>
+con di
+ne e
+si ons</w>
+te n</w>
+Th at</w>
+av ail
+beli eve</w>
+e ver
+ex peri
+l an
+oper ation</w>
+cont ro
+g ener
+i z
+ta k
+te e</w>
+th er
+the refore</w>
+Bu t</w>
+G er
+ab ility</w>
+be fore</w>
+gr ou
+l ed</w>
+off er</w>
+s ame</w>
+th or
+ch il
+chan ge</w>
+ful ly</w>
+th ink</w>
+cu st
+gr am
+sid e</w>
+O ur</w>
+busin ess</w>
+c c
+l ast</w>
+m an</w>
+me di
+off ers</w>
+serv ices</w>
+vi ron
+hel p</w>
+n er</w>
+o w</w>
+par t
+proc ess</w>
+s en
+t able</w>
+ro s
+so cial</w>
+um an</w>
+w ays</w>
+b er
+b le</w>
+d ic
+k now
+lo b
+on g
+ro oms</w>
+t al
+t ro
+l ess</w>
+ma in</w>
+op e</w>
+deb ate</w>
+du st
+el e
+el ec
+in d</w>
+inter national</w>
+ment al</w>
+nec ess
+po w
+poin t</w>
+s ite</w>
+so ci
+P ro
+bu d
+c us
+can not</w>
+d ra
+eff ec
+im pro
+tur al</w>
+en viron
+m ul
+n at
+s ince</w>
+shi p</w>
+b re
+cour se</w>
+fe ren
+i ons</w>
+me as
+n s</w>
+reg ul
+s ay</w>
+s k
+sm all</w>
+ac cep
+li f
+m es</w>
+me ans</w>
+or gan
+posi tion</w>
+t le</w>
+to o</w>
+ve ly</w>
+with out</w>
+Commis sion
+c all
+c ell
+di vi
+dis cus
+p ur
+pro ce
+th ough</w>
+A f
+M ar
+ac coun
+al ready</w>
+ap p
+pl ay
+r s</w>
+&# 9
+1 0</w>
+A u
+In ter
+d uring</w>
+dat a</w>
+en cy</w>
+fu ture</w>
+in stitu
+oo k</w>
+r is
+r it
+E n
+H owever</w>
+ad v
+cust om
+d es
+ex p
+finan cial</w>
+low ing</w>
+situ ation</w>
+S u
+a ff
+a ir</w>
+avail able</w>
+g es</w>
+h uman</w>
+it al</w>
+r e</w>
+st and
+wh e
+2 0
+e p
+em plo
+ent ly</w>
+in dust
+issu e</w>
+o p</w>
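
The file above, data/sample.bpe, is a subword-nmt merges file: after the #version: 0.2 header, each line holds one learned merge operation, with </w> marking a word-final symbol. A rough sketch of how it is consumed (mirroring what BPETransform.warm_up() and _tokenize() do further down; assumes the subword-nmt package is installed):

from subword_nmt.apply_bpe import BPE

# Load the merge operations; the default joiner is '@@'.
with open("data/sample.bpe", encoding="utf-8") as codes:
    bpe = BPE(codes=codes)

print(bpe.segment_tokens(["Hello", "world", "."]))
# the new tests expect ['H@@', 'ell@@', 'o', 'world', '.']
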
diff --git a/data/sample.sp.model b/data/sample.sp.model
new file mode 100644
index 00000000..2e44c6ea
--- /dev/null
+++ b/data/sample.sp.model
Binary files differ
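
data/sample.sp.model is a binary SentencePiece model used by the SentencePiece-based tests. A quick sketch of loading it directly (assumes the sentencepiece package, recent enough to support out_type=str, which the transform code below also relies on):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/sample.sp.model")
print(sp.encode("Hello world .", out_type=str))
# the new tests expect ['▁H', 'el', 'lo', '▁world', '▁.']
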
diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py
new file mode 100644
index 00000000..f6fc3231
--- /dev/null
+++ b/onmt/tests/test_transform.py
@@ -0,0 +1,465 @@
+"""Here come the tests for implemented transform."""
+import unittest
+
+import copy
+import yaml
+import math
+from argparse import Namespace
+from onmt.transforms import get_transforms_cls, get_specials, make_transforms
+from onmt.transforms.bart import BARTNoising
+
+
+class TestTransform(unittest.TestCase):
+ def test_transform_register(self):
+ builtin_transform = [
+ "filtertoolong",
+ "prefix",
+ "sentencepiece",
+ "bpe",
+ "onmt_tokenize",
+ "bart",
+ "switchout",
+ "tokendrop",
+ "tokenmask",
+ ]
+ get_transforms_cls(builtin_transform)
+
+ def test_vocab_required_transform(self):
+ transforms_cls = get_transforms_cls(["bart", "switchout"])
+ opt = Namespace(seed=-1, switchout_temperature=1.0)
+        # transforms that require a vocab will not be created if no vocab is provided
+ transforms = make_transforms(opt, transforms_cls, fields=None)
+ self.assertEqual(len(transforms), 0)
+ with self.assertRaises(ValueError):
+ transforms_cls["switchout"](opt).warm_up(vocabs=None)
+ transforms_cls["bart"](opt).warm_up(vocabs=None)
+
+ def test_transform_specials(self):
+ transforms_cls = get_transforms_cls(["prefix"])
+ corpora = yaml.safe_load("""
+ trainset:
+ path_src: data/src-train.txt
+ path_tgt: data/tgt-train.txt
+ transforms: ["prefix"]
+ weight: 1
+ src_prefix: "⦅_pf_src⦆"
+ tgt_prefix: "⦅_pf_tgt⦆"
+ """)
+ opt = Namespace(data=corpora)
+ specials = get_specials(opt, transforms_cls)
+ specials_expected = {"src": {"⦅_pf_src⦆"}, "tgt": {"⦅_pf_tgt⦆"}}
+ self.assertEqual(specials, specials_expected)
+
+
+class TestMiscTransform(unittest.TestCase):
+ def test_prefix(self):
+ prefix_cls = get_transforms_cls(["prefix"])["prefix"]
+ corpora = yaml.safe_load("""
+ trainset:
+ path_src: data/src-train.txt
+ path_tgt: data/tgt-train.txt
+ transforms: [prefix]
+ weight: 1
+ src_prefix: "⦅_pf_src⦆"
+ tgt_prefix: "⦅_pf_tgt⦆"
+ """)
+ opt = Namespace(data=corpora, seed=-1)
+ prefix_transform = prefix_cls(opt)
+ prefix_transform.warm_up()
+ self.assertIn("trainset", prefix_transform.prefix_dict)
+
+ ex_in = {
+ "src": ["Hello", "world", "."],
+ "tgt": ["Bonjour", "le", "monde", "."],
+ }
+ with self.assertRaises(ValueError):
+ prefix_transform.apply(ex_in)
+ prefix_transform.apply(ex_in, corpus_name="validset")
+ ex_out = prefix_transform.apply(ex_in, corpus_name="trainset")
+ self.assertEqual(ex_out["src"][0], "⦅_pf_src⦆")
+ self.assertEqual(ex_out["tgt"][0], "⦅_pf_tgt⦆")
+
+ def test_filter_too_long(self):
+ filter_cls = get_transforms_cls(["filtertoolong"])["filtertoolong"]
+ opt = Namespace(src_seq_length=100, tgt_seq_length=100)
+ filter_transform = filter_cls(opt)
+ # filter_transform.warm_up()
+ ex_in = {
+ "src": ["Hello", "world", "."],
+ "tgt": ["Bonjour", "le", "monde", "."],
+ }
+ ex_out = filter_transform.apply(ex_in)
+ self.assertIs(ex_out, ex_in)
+ filter_transform.tgt_seq_length = 2
+ ex_out = filter_transform.apply(ex_in)
+ self.assertIsNone(ex_out)
+
+
+class TestSubwordTransform(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.base_opts = {
+ "seed": 3431,
+ "share_vocab": False,
+ "src_subword_model": "data/sample.bpe",
+ "tgt_subword_model": "data/sample.bpe",
+ "src_subword_nbest": 1,
+ "tgt_subword_nbest": 1,
+ "src_subword_alpha": 0.0,
+ "tgt_subword_alpha": 0.0,
+ "src_subword_vocab": "",
+ "tgt_subword_vocab": "",
+ "src_vocab_threshold": 0,
+ "tgt_vocab_threshold": 0,
+ }
+
+ def test_bpe(self):
+ bpe_cls = get_transforms_cls(["bpe"])["bpe"]
+ opt = Namespace(**self.base_opts)
+ bpe_cls._validate_options(opt)
+ bpe_transform = bpe_cls(opt)
+ bpe_transform.warm_up()
+ ex = {
+ "src": ["Hello", "world", "."],
+ "tgt": ["Bonjour", "le", "monde", "."],
+ }
+ bpe_transform.apply(ex, is_train=True)
+ ex_gold = {
+ "src": ["H@@", "ell@@", "o", "world", "."],
+ "tgt": ["B@@", "on@@", "j@@", "our", "le", "mon@@", "de", "."],
+ }
+ self.assertEqual(ex, ex_gold)
+ # test BPE-dropout:
+ bpe_transform.dropout["src"] = 1.0
+ tokens = ["Another", "world", "."]
+ gold_bpe = ["A@@", "no@@", "ther", "world", "."]
+ gold_dropout = [
+ "A@@", "n@@", "o@@", "t@@", "h@@", "e@@", "r",
+ "w@@", "o@@", "r@@", "l@@", "d", ".",
+ ]
+        # 1. BPE dropout is disabled for non-training examples
+ after_bpe = bpe_transform._tokenize(tokens, is_train=False)
+ self.assertEqual(after_bpe, gold_bpe)
+        # 2. BPE dropout is enabled for training examples
+ after_bpe = bpe_transform._tokenize(tokens, is_train=True)
+ self.assertEqual(after_bpe, gold_dropout)
+        # 3. (NOTE) disabling dropout won't take effect for tokens already seen:
+        # the BPE cache mechanism returns the cached subwords for a known
+        # original token whenever dropout is disabled
+ after_bpe2 = bpe_transform._tokenize(tokens, is_train=False)
+ self.assertEqual(after_bpe2, gold_dropout)
+
+ def test_sentencepiece(self):
+ sp_cls = get_transforms_cls(["sentencepiece"])["sentencepiece"]
+ base_opt = copy.copy(self.base_opts)
+ base_opt["src_subword_model"] = "data/sample.sp.model"
+ base_opt["tgt_subword_model"] = "data/sample.sp.model"
+ opt = Namespace(**base_opt)
+ sp_cls._validate_options(opt)
+ sp_transform = sp_cls(opt)
+ sp_transform.warm_up()
+ ex = {
+ "src": ["Hello", "world", "."],
+ "tgt": ["Bonjour", "le", "monde", "."],
+ }
+ sp_transform.apply(ex, is_train=True)
+ ex_gold = {
+ "src": ["▁H", "el", "lo", "▁world", "▁."],
+ "tgt": ["▁B", "on", "j", "o", "ur", "▁le", "▁m", "on", "de", "▁."],
+ }
+ self.assertEqual(ex, ex_gold)
+ # test SP regularization:
+ sp_transform.src_subword_nbest = 4
+ tokens = ["Another", "world", "."]
+ gold_sp = ["▁An", "other", "▁world", "▁."]
+        # 1. regularization is enabled for training examples
+ after_sp = sp_transform._tokenize(tokens, is_train=True)
+ self.assertEqual(after_sp, ["▁An", "o", "ther", "▁world", "▁."])
+        # 2. regularization is disabled for non-training examples
+ after_sp = sp_transform._tokenize(tokens, is_train=False)
+ self.assertEqual(after_sp, gold_sp)
+
+ def test_pyonmttok_bpe(self):
+ onmttok_cls = get_transforms_cls(["onmt_tokenize"])["onmt_tokenize"]
+ base_opt = copy.copy(self.base_opts)
+ base_opt["src_subword_type"] = "bpe"
+ base_opt["tgt_subword_type"] = "bpe"
+ onmt_args = "{'mode': 'space', 'joiner_annotate': True}"
+ base_opt["src_onmttok_kwargs"] = onmt_args
+ base_opt["tgt_onmttok_kwargs"] = onmt_args
+ opt = Namespace(**base_opt)
+ onmttok_cls._validate_options(opt)
+ onmttok_transform = onmttok_cls(opt)
+ onmttok_transform.warm_up()
+ ex = {
+ "src": ["Hello", "world", "."],
+ "tgt": ["Bonjour", "le", "monde", "."],
+ }
+ onmttok_transform.apply(ex, is_train=True)
+ ex_gold = {
+ "src": ["H■", "ell■", "o", "world", "."],
+ "tgt": ["B■", "on■", "j■", "our", "le", "mon■", "de", "."],
+ }
+ self.assertEqual(ex, ex_gold)
+
+ def test_pyonmttok_sp(self):
+ onmttok_cls = get_transforms_cls(["onmt_tokenize"])["onmt_tokenize"]
+ base_opt = copy.copy(self.base_opts)
+ base_opt["src_subword_type"] = "sentencepiece"
+ base_opt["tgt_subword_type"] = "sentencepiece"
+ base_opt["src_subword_model"] = "data/sample.sp.model"
+ base_opt["tgt_subword_model"] = "data/sample.sp.model"
+ onmt_args = "{'mode': 'none', 'spacer_annotate': True}"
+ base_opt["src_onmttok_kwargs"] = onmt_args
+ base_opt["tgt_onmttok_kwargs"] = onmt_args
+ opt = Namespace(**base_opt)
+ onmttok_cls._validate_options(opt)
+ onmttok_transform = onmttok_cls(opt)
+ onmttok_transform.warm_up()
+ ex = {
+ "src": ["Hello", "world", "."],
+ "tgt": ["Bonjour", "le", "monde", "."],
+ }
+ onmttok_transform.apply(ex, is_train=True)
+ ex_gold = {
+ "src": ["▁H", "el", "lo", "▁world", "▁."],
+ "tgt": ["▁B", "on", "j", "o", "ur", "▁le", "▁m", "on", "de", "▁."],
+ }
+ self.assertEqual(ex, ex_gold)
+
+
+class TestSamplingTransform(unittest.TestCase):
+ def test_tokendrop(self):
+ tokendrop_cls = get_transforms_cls(["tokendrop"])["tokendrop"]
+ opt = Namespace(seed=3434, tokendrop_temperature=0.1)
+ tokendrop_transform = tokendrop_cls(opt)
+ tokendrop_transform.warm_up()
+ ex = {
+ "src": ["Hello", ",", "world", "."],
+ "tgt": ["Bonjour", "le", "monde", "."],
+ }
+        # token drop is not applied to non-training examples
+ ex_after = tokendrop_transform.apply(copy.deepcopy(ex), is_train=False)
+ self.assertEqual(ex_after, ex)
+        # token drop is applied to training examples
+ ex_after = tokendrop_transform.apply(copy.deepcopy(ex), is_train=True)
+ self.assertNotEqual(ex_after, ex)
+
+ def test_tokenmask(self):
+ tokenmask_cls = get_transforms_cls(["tokenmask"])["tokenmask"]
+ opt = Namespace(seed=3434, tokenmask_temperature=0.1)
+ tokenmask_transform = tokenmask_cls(opt)
+ tokenmask_transform.warm_up()
+ ex = {
+ "src": ["Hello", ",", "world", "."],
+ "tgt": ["Bonjour", "le", "monde", "."],
+ }
+        # token mask is not applied to non-training examples
+ ex_after = tokenmask_transform.apply(copy.deepcopy(ex), is_train=False)
+ self.assertEqual(ex_after, ex)
+        # token mask is applied to training examples
+ ex_after = tokenmask_transform.apply(copy.deepcopy(ex), is_train=True)
+ self.assertNotEqual(ex_after, ex)
+
+ def test_switchout(self):
+ switchout_cls = get_transforms_cls(["switchout"])["switchout"]
+ opt = Namespace(seed=3434, switchout_temperature=0.1)
+ switchout_transform = switchout_cls(opt)
+ with self.assertRaises(ValueError):
+ # require vocabs to warm_up
+ switchout_transform.warm_up(vocabs=None)
+ vocabs = {
+ "src": Namespace(itos=["A", "Fake", "vocab"]),
+ "tgt": Namespace(itos=["A", "Fake", "vocab"]),
+ }
+ switchout_transform.warm_up(vocabs=vocabs)
+ ex = {
+ "src": ["Hello", ",", "world", "."],
+ "tgt": ["Bonjour", "le", "monde", "."],
+ }
+        # switchout is not applied to non-training examples
+ ex_after = switchout_transform.apply(copy.deepcopy(ex), is_train=False)
+ self.assertEqual(ex_after, ex)
+        # switchout is applied to training examples
+ ex_after = switchout_transform.apply(copy.deepcopy(ex), is_train=True)
+ self.assertNotEqual(ex_after, ex)
+
+
+class TestBARTNoising(unittest.TestCase):
+ def setUp(self):
+ BARTNoising.set_random_seed(1234)
+ self.MASK_TOK = "[MASK]"
+ self.FAKE_VOCAB = "[TESTING]"
+
+ def test_sentence_permute(self):
+ sent1 = ["Hello", "world", "."]
+ sent2 = ["Sentence", "1", "!"]
+ sent3 = ["Sentence", "2", "!"]
+ sent4 = ["Sentence", "3", "!"]
+
+ bart_noise = BARTNoising(
+ vocab=[self.FAKE_VOCAB],
+ permute_sent_ratio=0.5,
+            replace_length=0,  # so no error is raised
+            # Default: full_stop_token=[".", "?", "!"]
+ )
+ tokens = sent1 + sent2 + sent3 + sent4
+ ends = bart_noise._get_sentence_borders(tokens).tolist()
+ self.assertEqual(ends, [3, 6, 9, 12])
+ tokens_perm = bart_noise.apply(tokens)
+ expected_tokens = sent2 + sent1 + sent3 + sent4
+ self.assertEqual(expected_tokens, tokens_perm)
+
+ def test_rotate(self):
+ bart_noise = BARTNoising(
+ vocab=[self.FAKE_VOCAB],
+ rotate_ratio=1.0,
+            replace_length=0,  # so no error is raised
+ )
+ tokens = ["This", "looks", "really", "good", "!"]
+ rotated = bart_noise.apply(tokens)
+ self.assertNotEqual(tokens, rotated)
+ not_rotate = bart_noise.rolling_noise(tokens, p=0.0)
+ self.assertEqual(tokens, not_rotate)
+
+ def test_token_insert(self):
+ bart_noise = BARTNoising(
+ vocab=[self.FAKE_VOCAB],
+ mask_tok=self.MASK_TOK,
+ insert_ratio=0.5,
+ random_ratio=0.3,
+            replace_length=0,  # so no error is raised
+            # Default: full_stop_token=[".", "?", "!"]
+ )
+ tokens = ["This", "looks", "really", "good", "!"]
+ inserted = bart_noise.apply(tokens)
+ n_insert = math.ceil(len(tokens) * bart_noise.insert_ratio)
+ inserted_len = n_insert + len(tokens)
+ self.assertEqual(len(inserted), inserted_len)
+        # a random_ratio fraction of the inserted tokens is drawn from the vocab
+ n_random = math.ceil(n_insert * bart_noise.random_ratio)
+ self.assertEqual(
+ sum(1 if tok == self.FAKE_VOCAB else 0 for tok in inserted),
+ n_random,
+ )
+ # others are MASK_TOK
+ self.assertEqual(
+ sum(1 if tok == self.MASK_TOK else 0 for tok in inserted),
+ n_insert - n_random,
+ )
+
+ def test_token_mask(self):
+ """Mask will be done on token level.
+
+ Condition:
+ * `mask_length` == subword;
+        * or no subword marker (joiner/spacer) is specified via `is_joiner`.
+ """
+ bart_noise = BARTNoising(
+ vocab=[self.FAKE_VOCAB],
+ mask_tok=self.MASK_TOK,
+ mask_ratio=0.5,
+ mask_length="subword",
+ replace_length=0, # 0 to drop them, 1 to replace them with MASK
+ # insert_ratio=0.0,
+ # random_ratio=0.0,
+            # Default: full_stop_token=[".", "?", "!"]
+ )
+ tokens = ["H■", "ell■", "o", "world", "."]
+        # all tokens are considered individual words
+ self.assertTrue(all(bart_noise._is_word_start(tokens)))
+ n_tokens = len(tokens)
+
+ # 1. tokens are dropped when replace_length is 0
+ masked = bart_noise.apply(tokens)
+ n_masked = math.ceil(n_tokens * bart_noise.mask_ratio)
+ # print(f"token delete: {masked} / {tokens}")
+ self.assertEqual(len(masked), n_tokens - n_masked)
+
+ # 2. tokens are replaced by MASK when replace_length is 1
+ bart_noise.replace_length = 1
+ masked = bart_noise.apply(tokens)
+ n_masked = math.ceil(n_tokens * bart_noise.mask_ratio)
+ # print(f"token mask: {masked} / {tokens}")
+ self.assertEqual(len(masked), n_tokens)
+ self.assertEqual(
+ sum([1 if tok == self.MASK_TOK else 0 for tok in masked]), n_masked
+ )
+
+ def test_whole_word_mask(self):
+ """Mask will be done on whole word that may across multiply token.
+
+ Condition:
+ * `mask_length` == word;
+        * a subword marker is specified in order to find word boundaries.
+ """
+ bart_noise = BARTNoising(
+ vocab=[self.FAKE_VOCAB],
+ mask_tok=self.MASK_TOK,
+ mask_ratio=0.5,
+ mask_length="word",
+ is_joiner=True,
+ replace_length=0, # 0 to drop them, 1 to replace them with MASK
+ # insert_ratio=0.0,
+ # random_ratio=0.0,
+            # Default: full_stop_token=[".", "?", "!"]
+ )
+ tokens = ["H■", "ell■", "o", "wor■", "ld", "."]
+        # the start token of each word is identified using the subword marker
+ token_starts = [True, False, False, True, False, True]
+ self.assertEqual(bart_noise._is_word_start(tokens), token_starts)
+
+ # 1. replace_length 0: "words" are dropped
+ masked = bart_noise.apply(copy.copy(tokens))
+ n_words = sum(token_starts)
+ n_masked = math.ceil(n_words * bart_noise.mask_ratio)
+ # print(f"word delete: {masked} / {tokens}")
+ # self.assertEqual(len(masked), n_words - n_masked)
+
+ # 2. replace_length 1: "words" are replaced with a single MASK
+ bart_noise.replace_length = 1
+ masked = bart_noise.apply(copy.copy(tokens))
+ # print(f"whole word single mask: {masked} / {tokens}")
+        # len(masked) depends on the number of tokens in the selected words
+ n_words = sum(token_starts)
+ n_masked = math.ceil(n_words * bart_noise.mask_ratio)
+ self.assertEqual(
+ sum(1 if tok == self.MASK_TOK else 0 for tok in masked), n_masked
+ )
+
+ # 3. replace_length -1: all tokens in "words" are replaced with MASK
+ bart_noise.replace_length = -1
+ masked = bart_noise.apply(copy.copy(tokens))
+ # print(f"whole word multi mask: {masked} / {tokens}")
+ self.assertEqual(len(masked), len(tokens)) # length won't change
+ n_words = sum(token_starts)
+ n_masked = math.ceil(n_words * bart_noise.mask_ratio)
+        # the number of MASK_TOK depends on the number of tokens in the
+        # selected words and can be greater than n_masked
+ self.assertTrue(
+ sum(1 if tok == self.MASK_TOK else 0 for tok in masked) > n_masked
+ )
+
+ def test_span_infilling(self):
+ bart_noise = BARTNoising(
+ vocab=[self.FAKE_VOCAB],
+ mask_tok=self.MASK_TOK,
+ mask_ratio=0.5,
+ mask_length="span-poisson",
+ poisson_lambda=3.0,
+ is_joiner=True,
+ replace_length=1,
+ # insert_ratio=0.5,
+ # random_ratio=0.3,
+            # Default: full_stop_token=[".", "?", "!"]
+ )
+ self.assertIsNotNone(bart_noise.mask_span_distribution)
+ tokens = ["H■", "ell■", "o", "world", ".", "An■", "other", "!"]
+        # the start token of each word is identified using the subword marker
+ token_starts = [True, False, False, True, True, True, False, True]
+ self.assertEqual(bart_noise._is_word_start(tokens), token_starts)
+ bart_noise.apply(copy.copy(tokens))
+ # n_words = sum(token_starts)
+ # n_masked = math.ceil(n_words * bart_noise.mask_ratio)
+ # print(f"Text Span Infilling: {infillied} / {tokens}")
+ # print(n_words, n_masked)
diff --git a/onmt/transforms/bart.py b/onmt/transforms/bart.py
index ed116dcf..212c20ef 100644
--- a/onmt/transforms/bart.py
+++ b/onmt/transforms/bart.py
@@ -88,6 +88,12 @@ class BARTNoising(object):
self.mask_length = mask_length
self.poisson_lambda = poisson_lambda
+ @staticmethod
+ def set_random_seed(seed):
+ """Call this before use to ensure reproducibility."""
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
def _make_poisson(self, poisson_lambda):
lambda_to_the_k = 1
e_to_the_minus_lambda = math.exp(-poisson_lambda)
@@ -102,20 +108,25 @@ class BARTNoising(object):
ps = torch.FloatTensor(ps)
return torch.distributions.Categorical(ps)
- def _is_full_stop(self, token):
- return True if token in self.full_stop_token else False
-
- def permute_sentences(self, tokens, p=1.0):
- if len(tokens) == 1:
- return tokens
- full_stops = np.array([self._is_full_stop(token) for token in tokens])
+ def _get_sentence_borders(self, tokens):
+ """Return lengths of each sentence in the token sequence."""
+ full_stops = np.array(
+ [
+ True if token in self.full_stop_token else False
+ for token in tokens
+ ]
+ )
# Pretend it ends with a full stop so last span is a sentence
full_stops[-1] = True
-
# Tokens that are full stops, where the previous token is not
- sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero()[0] + 2
+ sentence_lens = (full_stops[1:] * ~full_stops[:-1]).nonzero()[0] + 2
+ return sentence_lens
- n_sentences = sentence_ends.size
+ def permute_sentences(self, tokens, p=1.0):
+ if len(tokens) == 1:
+ return tokens
+ sentence_lens = self._get_sentence_borders(tokens)
+ n_sentences = sentence_lens.size
if n_sentences == 1:
return tokens
@@ -129,8 +140,8 @@ class BARTNoising(object):
result = [tok for tok in tokens]
index = 0
for i in ordering:
- sentence = tokens[(sentence_ends[i - 1] if i > 0 else 0):
- sentence_ends[i]]
+ sentence = tokens[(sentence_lens[i - 1] if i > 0 else 0):
+ sentence_lens[i]]
result[index:index + len(sentence)] = sentence
index += len(sentence)
assert len(result) == len(tokens), "Error when permute sentences."
@@ -331,8 +342,7 @@ class BARTNoiseTransform(Transform):
def _set_seed(self, seed):
"""set seed to ensure reproducibility."""
- np.random.seed(seed)
- torch.manual_seed(seed)
+ BARTNoising.set_random_seed(seed)
@classmethod
def add_options(cls, parser):
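
The hunks above factor the sentence-boundary computation out of permute_sentences() into _get_sentence_borders(), so the new tests can call it directly. A standalone sketch of what that expression computes (pure numpy; the token list is illustrative):

import numpy as np

tokens = ["Hello", "world", ".", "Sentence", "1", "!", "Sentence", "2", "!"]
full_stops = np.array([tok in {".", "?", "!"} for tok in tokens])
full_stops[-1] = True  # pretend the sequence ends with a full stop
borders = (full_stops[1:] * ~full_stops[:-1]).nonzero()[0] + 2
print(borders.tolist())  # [3, 6, 9]: exclusive end index of each sentence
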
diff --git a/onmt/transforms/sampling.py b/onmt/transforms/sampling.py
index c4a48e7b..b816ebce 100644
--- a/onmt/transforms/sampling.py
+++ b/onmt/transforms/sampling.py
@@ -1,7 +1,6 @@
"""Transforms relate to hamming distance sampling."""
import random
import numpy as np
-from onmt.utils.logging import logger
from onmt.constants import DefaultTokens
from onmt.transforms import register_transform
from .transform import Transform
@@ -79,25 +78,19 @@ class SwitchOutTransform(HammingDistanceSamplingTransform):
# 2. sample positions to corrput
chosen_indices = self._sample_position(tokens, distance=n_chosen)
# 3. sample corrupted values
- out = []
- for (i, tok) in enumerate(tokens):
- if i in chosen_indices:
- tok = self._sample_replace(vocab, reject=tok)
- out.append(tok)
- else:
- out.append(tok)
+ for i in chosen_indices:
+ tokens[i] = self._sample_replace(vocab, reject=tokens[i])
if stats is not None:
stats.switchout(n_switchout=n_chosen, n_total=len(tokens))
- return out
+ return tokens
def apply(self, example, is_train=False, stats=None, **kwargs):
"""Apply switchout to both src and tgt side tokens."""
if is_train:
- src = self._switchout(
+ example['src'] = self._switchout(
example['src'], self.vocabs['src'].itos, stats)
- tgt = self._switchout(
+ example['tgt'] = self._switchout(
example['tgt'], self.vocabs['tgt'].itos, stats)
- example['src'], example['tgt'] = src, tgt
return example
def _repr_args(self):
@@ -124,6 +117,7 @@ class TokenDropTransform(HammingDistanceSamplingTransform):
self.temperature = self.opts.tokendrop_temperature
def _token_drop(self, tokens, stats=None):
+ n_items = len(tokens)
# 1. sample number of tokens to corrupt
n_chosen = self._sample_distance(tokens, self.temperature)
# 2. sample positions to corrput
@@ -132,20 +126,19 @@ class TokenDropTransform(HammingDistanceSamplingTransform):
out = [tok for (i, tok) in enumerate(tokens)
if i not in chosen_indices]
if stats is not None:
- stats.token_drop(n_dropped=n_chosen, n_total=len(tokens))
+ stats.token_drop(n_dropped=n_chosen, n_total=n_items)
return out
def apply(self, example, is_train=False, stats=None, **kwargs):
"""Apply token drop to both src and tgt side tokens."""
if is_train:
- src = self._token_drop(example['src'], stats)
- tgt = self._token_drop(example['tgt'], stats)
- example['src'], example['tgt'] = src, tgt
+ example['src'] = self._token_drop(example['src'], stats)
+ example['tgt'] = self._token_drop(example['tgt'], stats)
return example
def _repr_args(self):
"""Return str represent key arguments for class."""
- return '{}={}'.format('worddrop_temperature', self.temperature)
+ return '{}={}'.format('tokendrop_temperature', self.temperature)
@register_transform(name='tokenmask')
@@ -179,20 +172,16 @@ class TokenMaskTransform(HammingDistanceSamplingTransform):
# 2. sample positions to corrput
chosen_indices = self._sample_position(tokens, distance=n_chosen)
# 3. mask word on chosen position
- out = []
- for (i, tok) in enumerate(tokens):
- tok = self.MASK_TOK if i in chosen_indices else tok
- out.append(tok)
+ for i in chosen_indices:
+ tokens[i] = self.MASK_TOK
if stats is not None:
stats.token_mask(n_masked=n_chosen, n_total=len(tokens))
- return out
+ return tokens
def apply(self, example, is_train=False, stats=None, **kwargs):
"""Apply word drop to both src and tgt side tokens."""
if is_train:
- src = self._token_mask(example['src'], stats)
- tgt = self._token_mask(example['tgt'], stats)
- example['src'], example['tgt'] = src, tgt
+ example['src'] = self._token_mask(example['src'], stats)
return example
def _repr_args(self):
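
These sampling hunks switch from building a new out list to corrupting the token list in place and returning the same object, which apply() then assigns back onto the example. A toy illustration of the pattern (names are illustrative, not the module's API):

def corrupt_in_place(tokens, chosen_indices, replace):
    # Mutate only the chosen positions and return the same list object.
    for i in chosen_indices:
        tokens[i] = replace(tokens[i])
    return tokens

toks = ["Hello", ",", "world", "."]
out = corrupt_in_place(toks, [1, 3], lambda _tok: "<mask>")
print(out is toks, toks)  # True ['Hello', '<mask>', 'world', '<mask>']
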
diff --git a/onmt/transforms/tokenize.py b/onmt/transforms/tokenize.py
index 4f343e47..b01b10fb 100644
--- a/onmt/transforms/tokenize.py
+++ b/onmt/transforms/tokenize.py
@@ -153,14 +153,14 @@ class SentencePieceTransform(TokenizerTransform):
sentence = ' '.join(tokens)
nbest_size = self.tgt_subword_nbest if side == 'tgt' else \
self.src_subword_nbest
- alpha = self.tgt_subword_alpha if side == 'tgt' else \
- self.src_subword_alpha
if is_train is False or nbest_size in [0, 1]:
# derterministic subwording
segmented = sp_model.encode(sentence, out_type=str)
else:
# subword sampling when nbest_size > 1 or -1
# alpha should be 0.0 < alpha < 1.0
+ alpha = self.tgt_subword_alpha if side == 'tgt' else \
+ self.src_subword_alpha
segmented = sp_model.encode(
sentence, out_type=str, enable_sampling=True,
alpha=alpha, nbest_size=nbest_size)
@@ -208,26 +208,25 @@ class BPETransform(TokenizerTransform):
"""Load subword models."""
super().warm_up(None)
from subword_nmt.apply_bpe import BPE, read_vocabulary
- import codecs
- src_codes = codecs.open(self.src_subword_model, encoding='utf-8')
+ # Load vocabulary file if provided and set threshold
src_vocabulary, tgt_vocabulary = None, None
if self.src_subword_vocab != "" and self.src_vocab_threshold > 0:
- src_vocabulary = read_vocabulary(
- codecs.open(self.src_subword_vocab, encoding='utf-8'),
- self.src_vocab_threshold)
+ with open(self.src_subword_vocab, encoding='utf-8') as _sv:
+ src_vocabulary = read_vocabulary(_sv, self.src_vocab_threshold)
if self.tgt_subword_vocab != "" and self.tgt_vocab_threshold > 0:
- tgt_vocabulary = read_vocabulary(
- codecs.open(self.tgt_subword_vocab, encoding='utf-8'),
- self.tgt_vocab_threshold)
- load_src_model = BPE(codes=src_codes, vocab=src_vocabulary)
+ with open(self.tgt_subword_vocab, encoding='utf-8') as _tv:
+ tgt_vocabulary = read_vocabulary(_tv, self.tgt_vocab_threshold)
+ # Load Subword Model
+ with open(self.src_subword_model, encoding='utf-8') as src_codes:
+ load_src_model = BPE(codes=src_codes, vocab=src_vocabulary)
if self.share_vocab and (src_vocabulary == tgt_vocabulary):
self.load_models = {
'src': load_src_model,
'tgt': load_src_model
}
else:
- tgt_codes = codecs.open(self.tgt_subword_model, encoding='utf-8')
- load_tgt_model = BPE(codes=tgt_codes, vocab=tgt_vocabulary)
+ with open(self.tgt_subword_model, encoding='utf-8') as tgt_codes:
+ load_tgt_model = BPE(codes=tgt_codes, vocab=tgt_vocabulary)
self.load_models = {
'src': load_src_model,
'tgt': load_tgt_model
@@ -236,7 +235,7 @@ class BPETransform(TokenizerTransform):
def _tokenize(self, tokens, side='src', is_train=False):
"""Do bpe subword tokenize."""
bpe_model = self.load_models[side]
- dropout = self.dropout[side] if is_train else 0
+ dropout = self.dropout[side] if is_train else 0.0
segmented = bpe_model.segment_tokens(tokens, dropout=dropout)
return segmented
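
The last hunk only makes the non-training dropout an explicit float. For context, a sketch of the BPE-dropout behaviour exercised by the new test_bpe test (assumes a subword-nmt release that supports the dropout argument of segment_tokens):

from subword_nmt.apply_bpe import BPE

with open("data/sample.bpe", encoding="utf-8") as codes:
    bpe = BPE(codes=codes)

tokens = ["Another", "world", "."]
print(bpe.segment_tokens(tokens, dropout=0.0))  # deterministic merges
print(bpe.segment_tokens(tokens, dropout=1.0))  # all merges dropped: character-level pieces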