Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/OpenNMT-tf.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGuillaume Klein <guillaumekln@users.noreply.github.com>2022-06-28 20:07:09 +0300
committerGitHub <noreply@github.com>2022-06-28 20:07:09 +0300
commit5c65d057c21f856ccb6c5b3f0dbb5693d277303e (patch)
tree8021e973308299498e39800324b5f7ea29c48173
parent36c737d1446e475e87b71519a6e7791b22a0f919 (diff)
Fix dtype of SequenceRecord's length feature (#953)
-rw-r--r--opennmt/inputters/record_inputter.py2
-rw-r--r--opennmt/tests/inputter_test.py1
2 files changed, 2 insertions, 1 deletions
diff --git a/opennmt/inputters/record_inputter.py b/opennmt/inputters/record_inputter.py
index dc5d6890..26fbd5cd 100644
--- a/opennmt/inputters/record_inputter.py
+++ b/opennmt/inputters/record_inputter.py
@@ -48,7 +48,7 @@ class SequenceRecordInputter(Inputter):
},
)
tensor = feature_lists["values"]
- features["length"] = lengths["values"]
+ features["length"] = tf.cast(lengths["values"], tf.int32)
features["tensor"] = tf.cast(tensor, self.dtype)
return features
diff --git a/opennmt/tests/inputter_test.py b/opennmt/tests/inputter_test.py
index 48036665..f273d40c 100644
--- a/opennmt/tests/inputter_test.py
+++ b/opennmt/tests/inputter_test.py
@@ -770,6 +770,7 @@ class InputterTest(tf.test.TestCase):
features = next(iter(dataset))
lengths = features["length"]
tensors = features["tensor"]
+ self.assertEqual(lengths.dtype, tf.int32)
self.assertAllEqual(lengths, [3, 6, 1])
for length, tensor, expected_vector in zip(lengths, tensors, vectors):
self.assertAllClose(tensor[:length], expected_vector)