diff options
author | Guillaume Klein <guillaumekln@users.noreply.github.com> | 2022-06-28 20:07:09 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-06-28 20:07:09 +0300 |
commit | 5c65d057c21f856ccb6c5b3f0dbb5693d277303e (patch) | |
tree | 8021e973308299498e39800324b5f7ea29c48173 | |
parent | 36c737d1446e475e87b71519a6e7791b22a0f919 (diff) |
Fix dtype of SequenceRecord's length feature (#953)
-rw-r--r-- | opennmt/inputters/record_inputter.py | 2 | ||||
-rw-r--r-- | opennmt/tests/inputter_test.py | 1 |
2 files changed, 2 insertions, 1 deletions
diff --git a/opennmt/inputters/record_inputter.py b/opennmt/inputters/record_inputter.py index dc5d6890..26fbd5cd 100644 --- a/opennmt/inputters/record_inputter.py +++ b/opennmt/inputters/record_inputter.py @@ -48,7 +48,7 @@ class SequenceRecordInputter(Inputter): }, ) tensor = feature_lists["values"] - features["length"] = lengths["values"] + features["length"] = tf.cast(lengths["values"], tf.int32) features["tensor"] = tf.cast(tensor, self.dtype) return features diff --git a/opennmt/tests/inputter_test.py b/opennmt/tests/inputter_test.py index 48036665..f273d40c 100644 --- a/opennmt/tests/inputter_test.py +++ b/opennmt/tests/inputter_test.py @@ -770,6 +770,7 @@ class InputterTest(tf.test.TestCase): features = next(iter(dataset)) lengths = features["length"] tensors = features["tensor"] + self.assertEqual(lengths.dtype, tf.int32) self.assertAllEqual(lengths, [3, 6, 1]) for length, tensor, expected_vector in zip(lengths, tensors, vectors): self.assertAllClose(tensor[:length], expected_vector) |