Welcome to mirror list, hosted at ThFree Co, Russian Federation.

record_inputter.py « inputters « opennmt - github.com/OpenNMT/OpenNMT-tf.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: dc5d6890978c3c8d81c2fe790a8558c0f8465ea4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Define inputters reading from TFRecord files."""

import numpy as np
import tensorflow as tf

from opennmt.data import dataset as dataset_util
from opennmt.inputters.inputter import Inputter


class SequenceRecordInputter(Inputter):
    """Inputter that reads ``tf.train.SequenceExample``.

    See Also:
      :func:`opennmt.inputters.create_sequence_records` to generate a compatible
      dataset.
    """

    def __init__(self, input_depth, **kwargs):
        """Initializes the parameters of the record inputter.

        Args:
          input_depth: The depth dimension of the input vectors.
          **kwargs: Additional layer keyword arguments.
        """
        super().__init__(**kwargs)
        self.input_depth = input_depth

    def make_dataset(self, data_file, training=None):
        # Delegate to the shared helper so single files, lists of files, and
        # weighted datasets are all handled uniformly.
        return dataset_util.make_datasets(tf.data.TFRecordDataset, data_file)

    def input_signature(self):
        # "tensor" is [batch, time, depth]; "length" holds the unpadded time
        # dimension of each batch entry.
        signature = {}
        signature["tensor"] = tf.TensorSpec([None, None, self.input_depth], self.dtype)
        signature["length"] = tf.TensorSpec([None], tf.int32)
        return signature

    def make_features(self, element=None, features=None, training=None):
        features = {} if features is None else features
        # Skip parsing when the features were already built (e.g. when the
        # function is applied a second time in the pipeline).
        if "tensor" not in features:
            sequence_features = {
                "values": tf.io.FixedLenSequenceFeature(
                    [self.input_depth], dtype=tf.float32
                )
            }
            _, parsed, row_lengths = tf.io.parse_sequence_example(
                element, sequence_features=sequence_features
            )
            features["length"] = row_lengths["values"]
            features["tensor"] = tf.cast(parsed["values"], self.dtype)
        return features

    def call(self, features, training=None):
        return features["tensor"]


def write_sequence_record(vector, writer):
    """Writes a sequence vector as a TFRecord.

    Args:
      vector: A 2D Numpy float array of shape :math:`[T, D]`.
      writer: A ``tf.io.TFRecordWriter``.

    See Also:
      - :class:`opennmt.inputters.SequenceRecordInputter`
      - :func:`opennmt.inputters.create_sequence_records`
    """
    # One tf.train.Feature per timestep, each holding a D-dimensional row.
    step_features = []
    for row in vector.astype(np.float32):
        float_list = tf.train.FloatList(value=row)
        step_features.append(tf.train.Feature(float_list=float_list))
    feature_lists = tf.train.FeatureLists(
        feature_list={"values": tf.train.FeatureList(feature=step_features)}
    )
    example = tf.train.SequenceExample(feature_lists=feature_lists)
    writer.write(example.SerializeToString())


def create_sequence_records(vectors, path, compression=None):
    """Creates a TFRecord file of sequence vectors.

    Args:
      vectors: An iterable of 2D Numpy float arrays of shape :math:`[T, D]`.
      path: The output TFRecord file.
      compression: Optional compression type, can be "GZIP".

    Returns:
      Path to the TFRecord file. In most cases this is the same as :obj:`path` but
      if GZIP compression is enabled, the ".gz" extension is added if not already
      present.

    Raises:
      ValueError: if :obj:`compression` is invalid.

    See Also:
      - :class:`opennmt.inputters.SequenceRecordInputter`
      - :func:`opennmt.inputters.write_sequence_record`
    """
    if compression is not None:
        if compression not in ("GZIP",):
            raise ValueError("invalid compression type: %s" % compression)
        if compression == "GZIP" and not path.endswith(".gz"):
            path = "%s.gz" % path
    # Use the writer as a context manager so the file is flushed and closed
    # even if iterating :obj:`vectors` (or a write) raises; the original code
    # leaked the writer on any exception before the explicit close().
    with tf.io.TFRecordWriter(path, options=compression) as writer:
        for vector in vectors:
            write_sequence_record(vector, writer)
    return path