Welcome to mirror list, hosted at ThFree Co, Russian Federation.

sentencepiece_trainer.h « src - github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: a5c22d4005909a87ee82101f43f7f3798cbd8340 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
// Copyright 2018 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SENTENCEPIECE_TRAINER_H_
#define SENTENCEPIECE_TRAINER_H_

#include <string>

#include "sentencepiece_processor.h"
#include "third_party/absl/container/flat_hash_map.h"

namespace sentencepiece {

class TrainerSpec;
class NormalizerSpec;

namespace pretokenizer {
class PretokenizerForTrainingInterface;
}  // namespace pretokenizer

// Abstract iterator over the training sentences.
//
// Training sentences are loaded sequentially, one at a time:
//
// for (; !it.done(); it.Next()) {
//    const std::string &s = it.value();
// }
// RETURN_IF_ERROR(it.status());
//
// done() also returns true when iteration stops because of an error, so the
// caller must check status() after the loop to distinguish a clean finish
// from a failure.
class SentenceIterator {
 public:
  // Virtual so implementations can be destroyed through a base pointer.
  virtual ~SentenceIterator() = default;

  // Returns true once iteration has finished, including the error case.
  // Use status() to find out whether all sentences were loaded successfully.
  virtual bool done() const = 0;

  // Advances to the next sentence.
  virtual void Next() = 0;

  // Returns the current sentence.
  // NOTE(review): presumably only meaningful while done() is false — confirm
  // against the concrete implementations.
  virtual const std::string &value() const = 0;

  // Reports whether all sentences were loaded successfully; a non-OK status
  // describes the error that ended iteration.
  virtual util::Status status() const = 0;
};

// Static entry points for training SentencePiece models and for building or
// merging the specs that configure training.
//
// Every member is static. The constructor and destructor are private so the
// class cannot be instantiated.
class SentencePieceTrainer {
 public:
  // Trains a SentencePiece model with `trainer_spec`.
  // The default `normalizer_spec` is used.
  // When `sentence_iterator` is passed, sentences are loaded from the iterator.
  // NOTE(review): `serialized_model_proto`, when non-null, presumably receives
  // the trained model — confirm in the implementation.
  static util::Status Train(const TrainerSpec &trainer_spec,
                            SentenceIterator *sentence_iterator = nullptr,
                            std::string *serialized_model_proto = nullptr);

  // Trains a SentencePiece model with `trainer_spec` and `normalizer_spec`.
  // When `sentence_iterator` is passed, sentences are loaded from the iterator.
  static util::Status Train(const TrainerSpec &trainer_spec,
                            const NormalizerSpec &normalizer_spec,
                            SentenceIterator *sentence_iterator = nullptr,
                            std::string *serialized_model_proto = nullptr);

  // Trains a SentencePiece model with `trainer_spec`, `normalizer_spec`
  // and `denormalizer_spec`.
  // When `sentence_iterator` is passed, sentences are loaded from the iterator.
  static util::Status Train(const TrainerSpec &trainer_spec,
                            const NormalizerSpec &normalizer_spec,
                            const NormalizerSpec &denormalizer_spec,
                            SentenceIterator *sentence_iterator = nullptr,
                            std::string *serialized_model_proto = nullptr);

  // Trains a SentencePiece model with the command-line string in `args`,
  // e.g.,
  // '--input=data --model_prefix=m --vocab_size=8192 --model_type=unigram'
  // When `sentence_iterator` is passed, sentences are loaded from the iterator.
  static util::Status Train(absl::string_view args,
                            SentenceIterator *sentence_iterator = nullptr,
                            std::string *serialized_model_proto = nullptr);

  // Trains a SentencePiece model with the key-value map in `kwargs`,
  // e.g., {{"input", "data"}, {"model_prefix", "m"}, {"vocab_size", "8192"}...}
  static util::Status Train(
      const absl::flat_hash_map<std::string, std::string> &kwargs,
      SentenceIterator *sentence_iterator = nullptr,
      std::string *serialized_model_proto = nullptr);

  // Convenience function that builds a normalizer spec from a pre-compiled
  // normalization name. Do not use this method in production: it crashes
  // when `name` is invalid. Useful for unit testing.
  static NormalizerSpec GetNormalizerSpec(absl::string_view name);

  // Populates the necessary fields (precompiled_charmap) from
  // `NormalizerSpec::name` or `NormalizerSpec::normalization_rule_tsv`.
  static util::Status PopulateNormalizerSpec(NormalizerSpec *normalizer_spec,
                                             bool is_denormalizer = false);

  // Overrides `trainer_spec`, `normalizer_spec` and `denormalizer_spec` with
  // the key-value pairs in `kwargs`.
  static util::Status MergeSpecsFromArgs(
      const absl::flat_hash_map<std::string, std::string> &kwargs,
      TrainerSpec *trainer_spec, NormalizerSpec *normalizer_spec,
      NormalizerSpec *denormalizer_spec);

  // Overrides `trainer_spec`, `normalizer_spec` and `denormalizer_spec` with
  // the command-line flags in `args`.
  static util::Status MergeSpecsFromArgs(absl::string_view args,
                                         TrainerSpec *trainer_spec,
                                         NormalizerSpec *normalizer_spec,
                                         NormalizerSpec *denormalizer_spec);

  // Injects a global pretokenizer that is applied at training time.
  // The pretokenizer is only used for extracting pieces.
  // TODO(taku): It would be better to inject per `trainer_spec`.
  static util::Status SetPretokenizerForTraining(
      const pretokenizer::PretokenizerForTrainingInterface *pretokenizer);

  // Returns the current pretokenizer. If no pretokenizer is defined, returns
  // nullptr.
  static const pretokenizer::PretokenizerForTrainingInterface *
  GetPretokenizerForTraining();

  // Helper function to set the field `name` to `value` in `message`.
  // When the field is repeated, multiple values can be passed as
  // comma-separated values. The field must not be a nested message.
  // The bodies of these functions are automatically generated by
  // data/gen_spec_parser.pl.
  static util::Status SetProtoField(const std::string &name,
                                    const std::string &value,
                                    TrainerSpec *message);

  static util::Status SetProtoField(const std::string &name,
                                    const std::string &value,
                                    NormalizerSpec *message);

  // Populates the model type from its string representation, e.g., "bpe".
  // Supported models: "unigram", "bpe", "word", "char".
  static util::Status PopulateModelTypeFromString(absl::string_view type,
                                                  TrainerSpec *trainer_spec);

 private:
  // Deliberately user-provided ({}), not `= default`: under C++17 a class
  // whose constructors are only defaulted still counts as an aggregate, and
  // `SentencePieceTrainer{}` would then bypass the private constructor.
  SentencePieceTrainer() {}
  ~SentencePieceTrainer() {}
};

}  // namespace sentencepiece

#endif  // SENTENCEPIECE_TRAINER_H_