Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/OpenNMT-py.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorValentin Berkes <16121857+funboarder13920@users.noreply.github.com>2021-09-20 18:24:43 +0300
committerGitHub <noreply@github.com>2021-09-20 18:24:43 +0300
commitc8081afa3fab4e169c34a84f9304cf37184f97d7 (patch)
tree494ecfbf85b1fd2e0bb569f29617c0dec04910e2
parent1360088af4ba515dd30bdff4d55c0e37f89f8946 (diff)
fix pos encoding error + target prefixing error in LM decoding (#2099)
-rw-r--r--.github/workflows/push.yml2
-rw-r--r--data/data_lm/gen-beam-sol.txt8
-rw-r--r--data/data_lm/gen-nucleus-sampling-sol.txt14
-rw-r--r--data/data_lm/gen-sampling-beams-sol.txt14
-rw-r--r--data/data_lm/gen-sampling-sol.txt10
-rwxr-xr-xonmt/tests/pull_request_chk.sh2
-rw-r--r--onmt/tests/test_translator.py36
-rw-r--r--onmt/translate/beam_search.py1
-rw-r--r--onmt/translate/greedy_search.py1
-rw-r--r--onmt/translate/translator.py26
-rw-r--r--onmt/utils/misc.py4
11 files changed, 77 insertions, 41 deletions
diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml
index 66d892ef..6b58b2c7 100644
--- a/.github/workflows/push.yml
+++ b/.github/workflows/push.yml
@@ -287,7 +287,7 @@ jobs:
-src data/data_lm/src-gen.txt \
-verbose -batch_size 10 \
-beam_size 10 \
- -seed 1 \
+ -seed 2 \
-random_sampling_topk 50 \
-random_sampling_topp 0.95 \
-random_sampling_temp 1 \
diff --git a/data/data_lm/gen-beam-sol.txt b/data/data_lm/gen-beam-sol.txt
index a2543961..e6a2656b 100644
--- a/data/data_lm/gen-beam-sol.txt
+++ b/data/data_lm/gen-beam-sol.txt
@@ -1,7 +1,7 @@
you !
-in German Presidency in German Presidency in German Presidency in German Prime Minister in the Netherlands , in Poland , in the Middle East suggests that the US Presidency .
-the top of the top of the top of the top of the top of the future .
+in German German Presidency in German Presidency in German Presidency in German Presidency in the Netherlands .
+the future .
"
ignored .
-changes .
-800 m2 can be found here .
+do ?
+800 m2 in images .
diff --git a/data/data_lm/gen-nucleus-sampling-sol.txt b/data/data_lm/gen-nucleus-sampling-sol.txt
index 7412a361..f1c954c7 100644
--- a/data/data_lm/gen-nucleus-sampling-sol.txt
+++ b/data/data_lm/gen-nucleus-sampling-sol.txt
@@ -1,7 +1,7 @@
-, then fall of Off or your beach and beauty of spa offers you stroll through the city standard .
-within the German German German German Presidency &apos;s fourth administrator can .
-the New World War recorded a battle of the top of the top of the most representative of the top of the importance of the importance of the other very important elements . unpleasant dealer prices .
-&quot; unspoilt rusticity of the village to the village Hotel Dory of the village region on the village Cathedral , mixed the village LiĊĦice together with the village center of the village Nr. 1 child), 11-17 years .
-fine National orphans and young involved in young renewable energies .
-segments of biometric risk lie ?
-Spain .
+, ?
+in German Presidency in German Presidency programmes first cosy apartment in Slovene style .
+higher in New Christ .
+"
+fine figure on serving angles and the fine words 20ies last century Without density properly prepared to strive as soon as easy as a enterprise events to fallen into procurement procedures corresponding increasingly difficult it .
+have any chance , staggered according to the requirement or reproductive schools .
+serious , the giant .
diff --git a/data/data_lm/gen-sampling-beams-sol.txt b/data/data_lm/gen-sampling-beams-sol.txt
index 23d78ab8..9ae748e7 100644
--- a/data/data_lm/gen-sampling-beams-sol.txt
+++ b/data/data_lm/gen-sampling-beams-sol.txt
@@ -1,7 +1,7 @@
-you ! Yours Ina .
-inspired to try to make these reports .
-a few years of the importance in New World Cup last century .
-I think 24th , received , Gift Shops , received .
-fine words is yet called Carl von Wogau report , has been called specialized in the importance of forms of the fifth generation of forms of the fifth generation of opportunities .
-makes diesel indispensable in the future .
-80 NGOs for Russian power needs of state .
+you ! All your reservation &apos;s news .
+in German Presidency in German Presidency criteria .
+the top of the top , the future .
+and what happens it .
+fine words , where banks fail else has underlined the usefulness of the energy .
+happened without any difficulties in a success .
+the hotel should be pulled apart .
diff --git a/data/data_lm/gen-sampling-sol.txt b/data/data_lm/gen-sampling-sol.txt
index 63dabfd2..33a12ae0 100644
--- a/data/data_lm/gen-sampling-sol.txt
+++ b/data/data_lm/gen-sampling-sol.txt
@@ -1,7 +1,7 @@
you !
-in German Presidency in German Presidency in German Presidency in German Presidency in the Netherlands Presidency in this: only 13 countries in the Netherlands Presidency in this: only 13 countries , in the Netherlands Presidency of 25 century , the Treaty of 25 century , the Treaty of 25 century , the United States .
-the top of the top of the top of the top of the top of the top of the top of the top of the top of the top of the top of the top of the top of the top of the top of the top of the importance of the future .
+in German Presidency in German Presidency in the Netherlands Presidency in this: only 13 countries in the Netherlands , in this: only 13 countries in Germany " s Treaty of 25 century .
+the top of the top of the top of the top of the top of the top of the top of the top of the top of the top of the top of the top of the top of the future of the future .
" s famous Italian Prime Minister said John Leahy , Airbus Chief Operating Officer , Customers .
-fine words made in fine words about the fine words about fine words about fine bouquet of the South Tibet .
-do not know what do not know what extent than you do not know what is to do .
-the hotel , the hotel is a large search for a trip to shops and restaurants .
+fine words is administered by India , South Tibet as India , South Tibet as India , South Tibet as claimed by China in China in the United States .
+do not know what is doing here .
+the hotel is a few weeks ago , the first class is a big government .
diff --git a/onmt/tests/pull_request_chk.sh b/onmt/tests/pull_request_chk.sh
index 70cd7682..cc34a614 100755
--- a/onmt/tests/pull_request_chk.sh
+++ b/onmt/tests/pull_request_chk.sh
@@ -397,7 +397,7 @@ ${PYTHON} translate.py -model ${TEST_DIR}/test_model_lm.pt \
-src ${DATA_DIR}/data_lm/src-gen.txt \
-verbose -batch_size 10 \
-beam_size 10 \
- -seed 1 \
+ -seed 2 \
-random_sampling_topk 50 \
-random_sampling_topp 0.95 \
-random_sampling_temp 1 \
diff --git a/onmt/tests/test_translator.py b/onmt/tests/test_translator.py
new file mode 100644
index 00000000..80d42171
--- /dev/null
+++ b/onmt/tests/test_translator.py
@@ -0,0 +1,36 @@
+import unittest
+from onmt.translate import GeneratorLM
+import torch
+
+
+class TestGeneratorLM(unittest.TestCase):
+ def test_split_src_to_prevent_padding_target_prefix_is_none_when_equal_size(
+ self,
+ ):
+ src = torch.randint(0, 10, (5, 6))
+ src_lengths = 5 * torch.ones(5)
+ (
+ src,
+ src_lengths,
+ target_prefix,
+ ) = GeneratorLM.split_src_to_prevent_padding(src, src_lengths)
+ self.assertIsNone(target_prefix)
+
+ def test_split_src_to_prevent_padding_target_prefix_is_ok_when_different_size(
+ self,
+ ):
+ default_length = 5
+ src = torch.randint(0, 10, (default_length, 6))
+ src_lengths = default_length * torch.ones(6, dtype=torch.int)
+ new_length = 4
+ src_lengths[1] = new_length
+ (
+ src,
+ src_lengths,
+ target_prefix,
+ ) = GeneratorLM.split_src_to_prevent_padding(src, src_lengths)
+ self.assertTupleEqual(src.shape, (new_length, 6))
+ self.assertTupleEqual(target_prefix.shape, (1, 6))
+ self.assertTrue(
+ src_lengths.equal(new_length * torch.ones(6, dtype=torch.int))
+ )
diff --git a/onmt/translate/beam_search.py b/onmt/translate/beam_search.py
index 3550352e..b15cbf65 100644
--- a/onmt/translate/beam_search.py
+++ b/onmt/translate/beam_search.py
@@ -357,7 +357,6 @@ class BeamSearchLM(BeamSearchBase):
(fn_map_state, _, src_map,
target_prefix) = self.initialize_tile(
None, src_lengths, src_map, target_prefix)
- src = fn_map_state(src, dim=1)
if device is None:
device = src.device
diff --git a/onmt/translate/greedy_search.py b/onmt/translate/greedy_search.py
index 1c4a43a6..2db4d9c4 100644
--- a/onmt/translate/greedy_search.py
+++ b/onmt/translate/greedy_search.py
@@ -274,6 +274,5 @@ class GreedySearchLM(GreedySearch):
(fn_map_state, _, self.memory_lengths,
src_map) = super(GreedySearchLM, self).initialize(
None, src_lengths, src_map, device, target_prefix)
- src = fn_map_state(src, dim=1)
return fn_map_state, src, self.memory_lengths, src_map
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index 4d37e982..4874870b 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -1004,20 +1004,23 @@ class GeneratorLM(Inference):
batch, src_vocabs, decode_strategy
)
- def split_src_to_prevent_padding(self, src, src_lengths):
+ @classmethod
+ def split_src_to_prevent_padding(cls, src, src_lengths):
min_len_batch = torch.min(src_lengths).item()
target_prefix = None
- if min_len_batch > 0 and min_len_batch <= src.size(0):
- # hack [min_len_batch-1:] because expect <bos>
- target_prefix = (
- src[min_len_batch - 1:]
- if min_len_batch > 0 and min_len_batch <= src.size(0)
- else None
- )
+ if min_len_batch > 0 and min_len_batch < src.size(0):
+ target_prefix = src[min_len_batch:]
src = src[:min_len_batch]
src_lengths[:] = min_len_batch
return src, src_lengths, target_prefix
+ def tile_to_beam_size_after_initial_step(self, fn_map_state, log_probs):
+ if fn_map_state is not None:
+ log_probs = fn_map_state(log_probs, dim=1)
+ self.model.decoder.map_state(fn_map_state)
+ log_probs = log_probs[-1]
+ return log_probs
+
def _translate_batch_with_strategy(
self, batch, src_vocabs, decode_strategy
):
@@ -1072,8 +1075,6 @@ class GeneratorLM(Inference):
src_map,
target_prefix=target_prefix,
)
- if fn_map_state is not None:
- self.model.decoder.map_state(fn_map_state)
# (4) Begin decoding step by step:
for step in range(decode_strategy.max_length):
@@ -1090,12 +1091,13 @@ class GeneratorLM(Inference):
src_vocabs,
memory_lengths=memory_lengths.clone(),
src_map=src_map,
- step=step,
+ step=step if step == 0 else step + src_lengths[0].item(),
batch_offset=decode_strategy.batch_offset,
)
if step == 0:
- log_probs = log_probs[-1]
+ log_probs = self.tile_to_beam_size_after_initial_step(
+ fn_map_state, log_probs)
decode_strategy.advance(log_probs, attn)
any_finished = decode_strategy.is_finished.any()
diff --git a/onmt/utils/misc.py b/onmt/utils/misc.py
index b13f362a..1a539d74 100644
--- a/onmt/utils/misc.py
+++ b/onmt/utils/misc.py
@@ -72,11 +72,11 @@ def tile(x, count, dim=0):
perm = list(range(len(x.size())))
if dim != 0:
perm[0], perm[dim] = perm[dim], perm[0]
- x = x.permute(perm).contiguous()
+ x = x.permute(perm)
out_size = list(x.size())
out_size[0] *= count
batch = x.size(0)
- x = x.view(batch, -1) \
+ x = x.contiguous().view(batch, -1) \
.transpose(0, 1) \
.repeat(count, 1) \
.transpose(0, 1) \