From 856daadbbfbf26da81152e70aba0406a11d5bedc Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Tue, 2 Jun 2020 01:56:48 +0900 Subject: Port absl::flat_hash_map --- src/bpe_model.cc | 8 ++++---- src/bpe_model.h | 2 +- src/bpe_model_trainer.cc | 3 ++- src/bpe_model_trainer.h | 6 +++--- src/builder.h | 2 +- src/char_model.h | 2 +- src/char_model_trainer.h | 2 +- src/common.h | 1 - src/freelist_test.cc | 14 ++++++++++++++ src/init_test.cc | 5 +++++ src/model_factory.h | 2 +- src/model_interface.cc | 4 ++-- src/model_interface.h | 8 ++++---- src/model_interface_test.cc | 7 +++---- src/normalizer.h | 2 +- src/pretokenizer_for_training.h | 2 +- src/sentencepiece_processor.cc | 2 +- src/sentencepiece_processor_test.cc | 12 ++++++------ src/sentencepiece_trainer.cc | 12 ++++++------ src/sentencepiece_trainer.h | 8 ++++---- src/sentencepiece_trainer_test.cc | 2 +- src/spm_decode_main.cc | 2 +- src/spm_encode_main.cc | 6 +++--- src/spm_export_vocab_main.cc | 2 +- src/spm_normalize_main.cc | 4 ++-- src/spm_train_main.cc | 2 +- src/trainer_factory.h | 2 +- src/trainer_interface.cc | 6 +++--- src/trainer_interface.h | 8 ++++---- src/trainer_interface_test.cc | 2 +- src/unicode_script.cc | 3 ++- src/unicode_script_map.h | 4 ++-- src/unigram_model.h | 2 +- src/unigram_model_test.cc | 2 +- src/unigram_model_trainer.cc | 10 +++++----- src/unigram_model_trainer.h | 2 +- src/unigram_model_trainer_test.cc | 2 +- src/word_model.h | 2 +- src/word_model_test.cc | 2 +- src/word_model_trainer.cc | 4 ++-- src/word_model_trainer.h | 2 +- third_party/absl/container/flat_hash_map.h | 29 +++++++++++++++++++++++++++++ third_party/absl/container/flat_hash_set.h | 29 +++++++++++++++++++++++++++++ 43 files changed, 155 insertions(+), 78 deletions(-) create mode 100644 third_party/absl/container/flat_hash_map.h create mode 100644 third_party/absl/container/flat_hash_set.h diff --git a/src/bpe_model.cc b/src/bpe_model.cc index b111f30..f1a97f4 100644 --- a/src/bpe_model.cc +++ b/src/bpe_model.cc @@ -16,12 +16,12 @@ #include #include #include -#include #include #include #include "bpe_model.h" #include "freelist.h" +#include "third_party/absl/container/flat_hash_map.h" #include "util.h" namespace sentencepiece { @@ -70,9 +70,9 @@ std::vector> Model::SampleEncode( // Reverse merge rules. // key: merged symbol, value: pair of original symbols. - std::unordered_map, - string_util::string_view_hash> + absl::flat_hash_map, + string_util::string_view_hash> rev_merge; // Pre-allocates SymbolPair for efficiency. diff --git a/src/bpe_model.h b/src/bpe_model.h index c6e1abe..8021d4e 100644 --- a/src/bpe_model.h +++ b/src/bpe_model.h @@ -15,8 +15,8 @@ #ifndef BPE_MODEL_H_ #define BPE_MODEL_H_ -#include "builtin_pb/sentencepiece_model.pb.h" #include "model_interface.h" +#include "sentencepiece_model.pb.h" namespace sentencepiece { namespace bpe { diff --git a/src/bpe_model_trainer.cc b/src/bpe_model_trainer.cc index 5a0cbdd..041df4a 100644 --- a/src/bpe_model_trainer.cc +++ b/src/bpe_model_trainer.cc @@ -18,6 +18,7 @@ #include #include "bpe_model_trainer.h" +#include "third_party/absl/container/flat_hash_set.h" #include "util.h" namespace sentencepiece { @@ -210,7 +211,7 @@ util::Status Trainer::Train() { // We may see duplicated pieces that are extracted with different path. // In real segmentation phase, we can consider them as one symbol. // e.g., "aaa" => "aa" + "a" or "a" + "aa". - std::unordered_set dup; + absl::flat_hash_set dup; // Main loop. CHECK_OR_RETURN(final_pieces_.empty()); diff --git a/src/bpe_model_trainer.h b/src/bpe_model_trainer.h index 051ac46..e011a37 100644 --- a/src/bpe_model_trainer.h +++ b/src/bpe_model_trainer.h @@ -17,10 +17,10 @@ #include #include -#include #include -#include "builtin_pb/sentencepiece_model.pb.h" +#include "sentencepiece_model.pb.h" +#include "third_party/absl/container/flat_hash_map.h" #include "trainer_interface.h" namespace sentencepiece { @@ -111,7 +111,7 @@ class Trainer : public TrainerInterface { void UpdateActiveSymbols(); // All unique symbols. Key is a fingerprint of Symbol. - std::unordered_map symbols_cache_; + absl::flat_hash_map symbols_cache_; // Set of symbols from which we find the best symbol in each iteration. std::set active_symbols_; diff --git a/src/builder.h b/src/builder.h index f0b959a..49d2884 100644 --- a/src/builder.h +++ b/src/builder.h @@ -19,8 +19,8 @@ #include #include -#include "builtin_pb/sentencepiece_model.pb.h" #include "common.h" +#include "sentencepiece_model.pb.h" #include "sentencepiece_processor.h" #include "third_party/absl/strings/string_view.h" diff --git a/src/char_model.h b/src/char_model.h index 23d0016..cd32875 100644 --- a/src/char_model.h +++ b/src/char_model.h @@ -15,8 +15,8 @@ #ifndef CHAR_MODEL_H_ #define CHAR_MODEL_H_ -#include "builtin_pb/sentencepiece_model.pb.h" #include "model_interface.h" +#include "sentencepiece_model.pb.h" namespace sentencepiece { namespace character { diff --git a/src/char_model_trainer.h b/src/char_model_trainer.h index f7b8a39..e563819 100644 --- a/src/char_model_trainer.h +++ b/src/char_model_trainer.h @@ -15,7 +15,7 @@ #ifndef CHAR_MODEL_TRAINER_H_ #define CHAR_MODEL_TRAINER_H_ -#include "builtin_pb/sentencepiece_model.pb.h" +#include "sentencepiece_model.pb.h" #include "trainer_interface.h" namespace sentencepiece { diff --git a/src/common.h b/src/common.h index 5d23e07..af0b1c2 100644 --- a/src/common.h +++ b/src/common.h @@ -15,7 +15,6 @@ #ifndef COMMON_H_ #define COMMON_H_ -#include #include #include #include diff --git a/src/freelist_test.cc b/src/freelist_test.cc index a7ff7de..9eb41a0 100644 --- a/src/freelist_test.cc +++ b/src/freelist_test.cc @@ -1,3 +1,17 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.! + #include "freelist.h" #include "testharness.h" diff --git a/src/init_test.cc b/src/init_test.cc index da659bf..9007bec 100644 --- a/src/init_test.cc +++ b/src/init_test.cc @@ -24,6 +24,9 @@ ABSL_FLAG(uint64, uint64_f, 30, "uint64_flags"); ABSL_FLAG(double, double_f, 40.0, "double_flags"); ABSL_FLAG(std::string, string_f, "str", "string_flags"); +ABSL_DECLARE_FLAG(bool, help); +ABSL_DECLARE_FLAG(bool, version); + using sentencepiece::ParseCommandLineFlags; namespace absl { @@ -89,6 +92,7 @@ TEST(FlagsTest, ParseCommandLineFlagsHelpTest) { int argc = arraysize(kFlags); char **argv = const_cast(kFlags); EXPECT_DEATH(ParseCommandLineFlags(kFlags[0], &argc, &argv), ""); + absl::SetFlag(&FLAGS_help, false); } TEST(FlagsTest, ParseCommandLineFlagsVersionTest) { @@ -96,6 +100,7 @@ TEST(FlagsTest, ParseCommandLineFlagsVersionTest) { int argc = arraysize(kFlags); char **argv = const_cast(kFlags); EXPECT_DEATH(ParseCommandLineFlags(kFlags[0], &argc, &argv), ""); + absl::SetFlag(&FLAGS_version, false); } TEST(FlagsTest, ParseCommandLineFlagsUnknownTest) { diff --git a/src/model_factory.h b/src/model_factory.h index 0502af1..76abce7 100644 --- a/src/model_factory.h +++ b/src/model_factory.h @@ -17,8 +17,8 @@ #include -#include "builtin_pb/sentencepiece_model.pb.h" #include "model_interface.h" +#include "sentencepiece_model.pb.h" namespace sentencepiece { diff --git a/src/model_interface.cc b/src/model_interface.cc index 43dfbd1..ea5d0e7 100644 --- a/src/model_interface.cc +++ b/src/model_interface.cc @@ -14,8 +14,8 @@ #include -#include "builtin_pb/sentencepiece_model.pb.h" #include "model_interface.h" +#include "sentencepiece_model.pb.h" #include "third_party/absl/memory/memory.h" #include "third_party/absl/strings/str_format.h" #include "util.h" @@ -174,7 +174,7 @@ std::string ByteToPiece(unsigned char c) { } int PieceToByte(absl::string_view piece) { - using PieceToByteMap = std::unordered_map; + using PieceToByteMap = absl::flat_hash_map; static const auto *const kMap = []() -> PieceToByteMap * { auto *m = new PieceToByteMap(); for (int i = 0; i < 256; ++i) { diff --git a/src/model_interface.h b/src/model_interface.h index 27dad99..75cbb23 100644 --- a/src/model_interface.h +++ b/src/model_interface.h @@ -18,14 +18,14 @@ #include #include #include -#include #include #include -#include "builtin_pb/sentencepiece_model.pb.h" #include "common.h" #include "normalizer.h" +#include "sentencepiece_model.pb.h" #include "sentencepiece_processor.h" +#include "third_party/absl/container/flat_hash_map.h" #include "third_party/absl/strings/string_view.h" #include "third_party/darts_clone/darts.h" #include "util.h" @@ -52,8 +52,8 @@ class ModelProto; // Given a normalized string, returns a sequence of sentence pieces with ids. class ModelInterface { public: - using PieceToIdMap = - std::unordered_map; + using PieceToIdMap = absl::flat_hash_map; absl::string_view unk_piece() const; absl::string_view bos_piece() const; diff --git a/src/model_interface_test.cc b/src/model_interface_test.cc index 52b045d..f5ee492 100644 --- a/src/model_interface_test.cc +++ b/src/model_interface_test.cc @@ -12,11 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include - #include "model_factory.h" #include "model_interface.h" #include "testharness.h" +#include "third_party/absl/container/flat_hash_map.h" #include "util.h" namespace sentencepiece { @@ -294,8 +293,8 @@ std::string RandomString(int length) { TEST(ModelInterfaceTest, PieceToIdStressTest) { for (const auto type : kModelTypes) { for (int i = 0; i < 100; ++i) { - std::unordered_map expected_p2i; - std::unordered_map expected_i2p; + absl::flat_hash_map expected_p2i; + absl::flat_hash_map expected_i2p; ModelProto model_proto = MakeBaseModelProto(type); for (int n = 0; n < 1000; ++n) { const std::string piece = RandomString(10); diff --git a/src/normalizer.h b/src/normalizer.h index 13166ca..ab12fac 100644 --- a/src/normalizer.h +++ b/src/normalizer.h @@ -21,8 +21,8 @@ #include #include -#include "builtin_pb/sentencepiece_model.pb.h" #include "common.h" +#include "sentencepiece_model.pb.h" #include "sentencepiece_processor.h" #include "third_party/absl/strings/string_view.h" #include "third_party/darts_clone/darts.h" diff --git a/src/pretokenizer_for_training.h b/src/pretokenizer_for_training.h index 0c84a08..2d3bc82 100644 --- a/src/pretokenizer_for_training.h +++ b/src/pretokenizer_for_training.h @@ -18,8 +18,8 @@ #include #include -#include "builtin_pb/sentencepiece.pb.h" #include "common.h" +#include "sentencepiece.pb.h" #include "sentencepiece_processor.h" #include "third_party/absl/strings/string_view.h" diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index a4dd575..1e87a80 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -16,12 +16,12 @@ #include #include -#include "builtin_pb/sentencepiece.pb.h" #include "common.h" #include "filesystem.h" #include "model_factory.h" #include "model_interface.h" #include "normalizer.h" +#include "sentencepiece.pb.h" #include "sentencepiece_processor.h" #include "third_party/absl/memory/memory.h" #include "third_party/absl/strings/numbers.h" diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index cb669e7..ef54071 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -12,18 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include #include #include "builder.h" -#include "builtin_pb/sentencepiece.pb.h" -#include "builtin_pb/sentencepiece_model.pb.h" #include "filesystem.h" #include "model_interface.h" #include "normalizer.h" +#include "sentencepiece.pb.h" +#include "sentencepiece_model.pb.h" #include "sentencepiece_processor.h" #include "sentencepiece_trainer.h" #include "testharness.h" +#include "third_party/absl/container/flat_hash_map.h" #include "third_party/absl/memory/memory.h" #include "third_party/absl/strings/str_cat.h" #include "third_party/absl/strings/string_view.h" @@ -551,8 +551,8 @@ TEST(SentencepieceProcessorTest, DecodeTest) { int GetPieceSize() const override { return 7; } int PieceToId(absl::string_view piece) const override { - static std::unordered_map + static absl::flat_hash_map kMap = {{"", 0}, {"", 1}, {"", 2}, {WS "ABC", 3}, {WS "DE", 4}, {"F", 5}, {"G" WS "H", 6}}; return port::FindWithDefault(kMap, piece, 0); @@ -695,7 +695,7 @@ TEST(SentencepieceProcessorTest, ByteFallbackDecodeTest) { } int PieceToId(absl::string_view piece) const override { - using Map = std::unordered_map; + using Map = absl::flat_hash_map; static const Map kMap = []() -> Map { Map m = { {"", 0}, {"", 1}, {"", 2}, {"A", 3}, {"B", 4}, {"C", 5}, diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc index f2b5050..48cfda4 100644 --- a/src/sentencepiece_trainer.cc +++ b/src/sentencepiece_trainer.cc @@ -16,10 +16,10 @@ #include #include "builder.h" -#include "builtin_pb/sentencepiece.pb.h" -#include "builtin_pb/sentencepiece_model.pb.h" #include "common.h" #include "normalizer.h" +#include "sentencepiece.pb.h" +#include "sentencepiece_model.pb.h" #include "sentencepiece_trainer.h" #include "spec_parser.h" #include "third_party/absl/flags/flag.h" @@ -108,7 +108,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs( if (args.empty()) return util::OkStatus(); - std::unordered_map kwargs; + absl::flat_hash_map kwargs; for (auto arg : absl::StrSplit(args, " ")) { absl::ConsumePrefix(&arg, "--"); std::string key, value; @@ -128,7 +128,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs( // static util::Status SentencePieceTrainer::MergeSpecsFromArgs( - const std::unordered_map &kwargs, + const absl::flat_hash_map &kwargs, TrainerSpec *trainer_spec, NormalizerSpec *normalizer_spec, NormalizerSpec *denormalizer_spec) { CHECK_OR_RETURN(trainer_spec) << "`trainer_spec` must not be null."; @@ -188,7 +188,7 @@ util::Status SentencePieceTrainer::Train(absl::string_view args, // static util::Status SentencePieceTrainer::Train( - const std::unordered_map &kwargs, + const absl::flat_hash_map &kwargs, SentenceIterator *sentence_iterator, std::string *serialized_model_proto) { TrainerSpec trainer_spec; NormalizerSpec normalizer_spec; @@ -230,7 +230,7 @@ util::Status SentencePieceTrainer::PopulateNormalizerSpec( // static util::Status SentencePieceTrainer::PopulateModelTypeFromString( absl::string_view type, TrainerSpec *spec) { - static const std::unordered_map + static const absl::flat_hash_map kModelTypeMap = {{"unigram", TrainerSpec::UNIGRAM}, {"bpe", TrainerSpec::BPE}, {"word", TrainerSpec::WORD}, diff --git a/src/sentencepiece_trainer.h b/src/sentencepiece_trainer.h index bb74ab9..a5c22d4 100644 --- a/src/sentencepiece_trainer.h +++ b/src/sentencepiece_trainer.h @@ -16,9 +16,9 @@ #define SENTENCEPIECE_TRAINER_H_ #include -#include #include "sentencepiece_processor.h" +#include "third_party/absl/container/flat_hash_map.h" namespace sentencepiece { @@ -85,7 +85,7 @@ class SentencePieceTrainer { // Trains SentencePiece model with mapin `kwargs`. // e.g., {{"input", "data"}, {"model_prefix, "m"}, {"vocab_size", "8192"}...} static util::Status Train( - const std::unordered_map &kwargs, + const absl::flat_hash_map &kwargs, SentenceIterator *sentence_iterator = nullptr, std::string *serialized_model_proto = nullptr); @@ -100,9 +100,9 @@ class SentencePieceTrainer { bool is_denormalizer = false); // Overrides `trainer_spec`, `normalizer_spec`, `denormalizer_spec` with the - // std::unordered_map in `kargs`. + // absl::flat_hash_map in `kargs`. static util::Status MergeSpecsFromArgs( - const std::unordered_map &kwargs, + const absl::flat_hash_map &kwargs, TrainerSpec *trainer_spec, NormalizerSpec *normalizer_spec, NormalizerSpec *denormalizer_spec); diff --git a/src/sentencepiece_trainer_test.cc b/src/sentencepiece_trainer_test.cc index 9c5614f..e44e66b 100644 --- a/src/sentencepiece_trainer_test.cc +++ b/src/sentencepiece_trainer_test.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "builtin_pb/sentencepiece_model.pb.h" #include "filesystem.h" +#include "sentencepiece_model.pb.h" #include "sentencepiece_trainer.h" #include "testharness.h" #include "third_party/absl/strings/str_cat.h" diff --git a/src/spm_decode_main.cc b/src/spm_decode_main.cc index 7284eb8..32cb382 100644 --- a/src/spm_decode_main.cc +++ b/src/spm_decode_main.cc @@ -16,10 +16,10 @@ #include #include -#include "builtin_pb/sentencepiece.pb.h" #include "common.h" #include "filesystem.h" #include "init.h" +#include "sentencepiece.pb.h" #include "sentencepiece_processor.h" #include "third_party/absl/flags/flag.h" #include "third_party/absl/strings/str_split.h" diff --git a/src/spm_encode_main.cc b/src/spm_encode_main.cc index 572cba5..4a51cb8 100644 --- a/src/spm_encode_main.cc +++ b/src/spm_encode_main.cc @@ -14,14 +14,14 @@ #include #include -#include #include -#include "builtin_pb/sentencepiece.pb.h" #include "common.h" #include "filesystem.h" #include "init.h" +#include "sentencepiece.pb.h" #include "sentencepiece_processor.h" +#include "third_party/absl/container/flat_hash_map.h" #include "third_party/absl/flags/flag.h" #include "third_party/absl/strings/str_cat.h" #include "third_party/absl/strings/str_join.h" @@ -83,7 +83,7 @@ int main(int argc, char *argv[]) { std::vector ids; std::vector> nbest_sps; std::vector> nbest_ids; - std::unordered_map vocab; + absl::flat_hash_map vocab; sentencepiece::SentencePieceText spt; sentencepiece::NBestSentencePieceText nbest_spt; std::function process; diff --git a/src/spm_export_vocab_main.cc b/src/spm_export_vocab_main.cc index 9b98f01..b5d93cb 100644 --- a/src/spm_export_vocab_main.cc +++ b/src/spm_export_vocab_main.cc @@ -15,10 +15,10 @@ #include -#include "builtin_pb/sentencepiece_model.pb.h" #include "common.h" #include "filesystem.h" #include "init.h" +#include "sentencepiece_model.pb.h" #include "sentencepiece_processor.h" #include "third_party/absl/flags/flag.h" diff --git a/src/spm_normalize_main.cc b/src/spm_normalize_main.cc index 244b974..96da360 100644 --- a/src/spm_normalize_main.cc +++ b/src/spm_normalize_main.cc @@ -13,12 +13,12 @@ // limitations under the License.! #include "builder.h" -#include "builtin_pb/sentencepiece.pb.h" -#include "builtin_pb/sentencepiece_model.pb.h" #include "common.h" #include "filesystem.h" #include "init.h" #include "normalizer.h" +#include "sentencepiece.pb.h" +#include "sentencepiece_model.pb.h" #include "sentencepiece_processor.h" #include "sentencepiece_trainer.h" #include "third_party/absl/flags/flag.h" diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc index 6d990e0..8a0912b 100644 --- a/src/spm_train_main.cc +++ b/src/spm_train_main.cc @@ -14,8 +14,8 @@ #include -#include "builtin_pb/sentencepiece_model.pb.h" #include "init.h" +#include "sentencepiece_model.pb.h" #include "sentencepiece_trainer.h" #include "third_party/absl/flags/flag.h" #include "third_party/absl/strings/ascii.h" diff --git a/src/trainer_factory.h b/src/trainer_factory.h index d563f7d..a11cbc0 100644 --- a/src/trainer_factory.h +++ b/src/trainer_factory.h @@ -17,7 +17,7 @@ #include -#include "builtin_pb/sentencepiece_model.pb.h" +#include "sentencepiece_model.pb.h" #include "trainer_interface.h" namespace sentencepiece { diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc index 5cdb300..eca7c8a 100644 --- a/src/trainer_interface.cc +++ b/src/trainer_interface.cc @@ -16,7 +16,6 @@ #include #include #include -#include #include #include @@ -26,6 +25,7 @@ #include "normalizer.h" #include "sentencepiece_processor.h" #include "sentencepiece_trainer.h" +#include "third_party/absl/container/flat_hash_map.h" #include "third_party/absl/memory/memory.h" #include "third_party/absl/strings/numbers.h" #include "third_party/absl/strings/str_cat.h" @@ -434,7 +434,7 @@ END: // Count character frequencies. int64 all_chars_count = 0; // A map from a character to {is_required_char, character count}. - std::unordered_map> chars_count; + absl::flat_hash_map> chars_count; for (const char32 c : string_util::UTF8ToUnicodeText(trainer_spec_.required_chars())) { CHECK_OR_RETURN(string_util::IsValidCodepoint(c)); @@ -526,7 +526,7 @@ END: void TrainerInterface::SplitSentencesByWhitespace() { LOG(INFO) << "Tokenizing input sentences with whitespace: " << sentences_.size(); - std::unordered_map tokens; + absl::flat_hash_map tokens; for (const auto &s : sentences_) { for (const auto &w : SplitIntoWords(s.first, trainer_spec_.treat_whitespace_as_suffix())) { diff --git a/src/trainer_interface.h b/src/trainer_interface.h index 552b206..f66d59a 100644 --- a/src/trainer_interface.h +++ b/src/trainer_interface.h @@ -19,15 +19,15 @@ #include #include #include -#include #include #include -#include "builtin_pb/sentencepiece_model.pb.h" #include "common.h" #include "filesystem.h" +#include "sentencepiece_model.pb.h" #include "sentencepiece_processor.h" #include "sentencepiece_trainer.h" +#include "third_party/absl/container/flat_hash_map.h" #include "util.h" namespace sentencepiece { @@ -44,7 +44,7 @@ std::vector> Sorted(const std::vector> &m) { } template -std::vector> Sorted(const std::unordered_map &m) { +std::vector> Sorted(const absl::flat_hash_map &m) { std::vector> v(m.begin(), m.end()); return Sorted(v); } @@ -129,7 +129,7 @@ class TrainerInterface { // Set of characters which must be included in the final vocab. // The value of this map stores the frequency. - std::unordered_map required_chars_; + absl::flat_hash_map required_chars_; // Final output pieces std::vector> final_pieces_; diff --git a/src/trainer_interface_test.cc b/src/trainer_interface_test.cc index 0144376..c61c7ce 100644 --- a/src/trainer_interface_test.cc +++ b/src/trainer_interface_test.cc @@ -466,7 +466,7 @@ TEST(TrainerInterfaceTest, CharactersTest) { trainer_spec.set_model_prefix("model"); trainer_spec.set_character_coverage(0.98); - using E = std::unordered_map; + using E = absl::flat_hash_map; { TrainerInterface trainer(trainer_spec, normalizer_spec, denormalizer_spec); EXPECT_OK(trainer.LoadSentences()); diff --git a/src/unicode_script.cc b/src/unicode_script.cc index 651b160..583dc30 100644 --- a/src/unicode_script.cc +++ b/src/unicode_script.cc @@ -14,6 +14,7 @@ #include +#include "third_party/absl/container/flat_hash_map.h" #include "unicode_script.h" #include "unicode_script_map.h" #include "util.h" @@ -30,7 +31,7 @@ class GetScriptInternal { } private: - std::unordered_map smap_; + absl::flat_hash_map smap_; }; } // namespace diff --git a/src/unicode_script_map.h b/src/unicode_script_map.h index 5e77c89..f2e67e9 100644 --- a/src/unicode_script_map.h +++ b/src/unicode_script_map.h @@ -14,11 +14,11 @@ #ifndef UNICODE_SCRIPT_DATA_H_ #define UNICODE_SCRIPT_DATA_H_ -#include +#include "third_party/absl/container/flat_hash_map.h" namespace sentencepiece { namespace unicode_script { namespace { -void InitTable(std::unordered_map *smap) { +void InitTable(absl::flat_hash_map *smap) { for (char32 c = 0x0000; c <= 0x001F; ++c) (*smap)[c] = U_Common; (*smap)[0x0020] = U_Common; for (char32 c = 0x0021; c <= 0x0023; ++c) (*smap)[c] = U_Common; diff --git a/src/unigram_model.h b/src/unigram_model.h index df84260..2f66a5f 100644 --- a/src/unigram_model.h +++ b/src/unigram_model.h @@ -20,10 +20,10 @@ #include #include -#include "builtin_pb/sentencepiece_model.pb.h" #include "common.h" #include "freelist.h" #include "model_interface.h" +#include "sentencepiece_model.pb.h" #include "third_party/darts_clone/darts.h" namespace sentencepiece { diff --git a/src/unigram_model_test.cc b/src/unigram_model_test.cc index e8ea0c6..dacec38 100644 --- a/src/unigram_model_test.cc +++ b/src/unigram_model_test.cc @@ -17,7 +17,7 @@ #include #include -#include "builtin_pb/sentencepiece_model.pb.h" +#include "sentencepiece_model.pb.h" #include "sentencepiece_processor.h" #include "testharness.h" #include "third_party/absl/strings/str_cat.h" diff --git a/src/unigram_model_trainer.cc b/src/unigram_model_trainer.cc index 99354af..86c7557 100644 --- a/src/unigram_model_trainer.cc +++ b/src/unigram_model_trainer.cc @@ -19,13 +19,13 @@ #include #include #include -#include #include #include #include "normalizer.h" #include "pretokenizer_for_training.h" #include "sentencepiece_trainer.h" +#include "third_party/absl/container/flat_hash_map.h" #include "third_party/absl/memory/memory.h" #include "third_party/esaxx/esa.hxx" // Suffix array library. #include "unicode_script.h" @@ -107,7 +107,7 @@ TrainerModel::SentencePieces Trainer::MakeSeedSentencePieces() const { // Merges all sentences into one array with 0x0000 delimiter. std::vector array; - std::unordered_map all_chars; + absl::flat_hash_map all_chars; constexpr char32 kSentenceBoundary = 0x0000; for (const auto &w : sentences_) { @@ -421,9 +421,9 @@ TrainerModel::SentencePieces Trainer::PruneSentencePieces( TrainerModel::SentencePieces Trainer::FinalizeSentencePieces( const TrainerModel &model) const { const auto &sentencepieces = model.GetSentencePieces(); - std::unordered_map final_sentencepieces; - std::unordered_map sp(sentencepieces.begin(), - sentencepieces.end()); + absl::flat_hash_map final_sentencepieces; + absl::flat_hash_map sp(sentencepieces.begin(), + sentencepieces.end()); // required_chars_ must be included in the final sentencepieces. float min_score_penalty = 0.0; diff --git a/src/unigram_model_trainer.h b/src/unigram_model_trainer.h index a0c1cea..91fbeb4 100644 --- a/src/unigram_model_trainer.h +++ b/src/unigram_model_trainer.h @@ -20,7 +20,7 @@ #include #include -#include "builtin_pb/sentencepiece_model.pb.h" +#include "sentencepiece_model.pb.h" #include "third_party/absl/strings/string_view.h" #include "trainer_interface.h" #include "unigram_model.h" diff --git a/src/unigram_model_trainer_test.cc b/src/unigram_model_trainer_test.cc index cca9936..ffe515e 100644 --- a/src/unigram_model_trainer_test.cc +++ b/src/unigram_model_trainer_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "builtin_pb/sentencepiece_model.pb.h" +#include "sentencepiece_model.pb.h" #include "sentencepiece_processor.h" #include "sentencepiece_trainer.h" #include "testharness.h" diff --git a/src/word_model.h b/src/word_model.h index 0048478..34470f9 100644 --- a/src/word_model.h +++ b/src/word_model.h @@ -15,8 +15,8 @@ #ifndef WORD_MODEL_H_ #define WORD_MODEL_H_ -#include "builtin_pb/sentencepiece_model.pb.h" #include "model_interface.h" +#include "sentencepiece_model.pb.h" namespace sentencepiece { namespace word { diff --git a/src/word_model_test.cc b/src/word_model_test.cc index 01c174c..aefb174 100644 --- a/src/word_model_test.cc +++ b/src/word_model_test.cc @@ -14,7 +14,7 @@ #include -#include "builtin_pb/sentencepiece_model.pb.h" +#include "sentencepiece_model.pb.h" #include "testharness.h" #include "util.h" #include "word_model.h" diff --git a/src/word_model_trainer.cc b/src/word_model_trainer.cc index fa6aeae..8d759e4 100644 --- a/src/word_model_trainer.cc +++ b/src/word_model_trainer.cc @@ -14,8 +14,8 @@ #include #include -#include +#include "third_party/absl/container/flat_hash_map.h" #include "third_party/absl/strings/string_view.h" #include "util.h" #include "word_model.h" @@ -32,7 +32,7 @@ util::Status Trainer::Train() { RETURN_IF_ERROR(LoadSentences()); - std::unordered_map freq; + absl::flat_hash_map freq; for (const auto &it : sentences_) { for (const auto &s : SplitIntoWords(it.first)) { freq[std::string(s)] += it.second; diff --git a/src/word_model_trainer.h b/src/word_model_trainer.h index 44aa657..76f8f32 100644 --- a/src/word_model_trainer.h +++ b/src/word_model_trainer.h @@ -15,7 +15,7 @@ #ifndef WORD_MODEL_TRAINER_H_ #define WORD_MODEL_TRAINER_H_ -#include "builtin_pb/sentencepiece_model.pb.h" +#include "sentencepiece_model.pb.h" #include "trainer_interface.h" namespace sentencepiece { diff --git a/third_party/absl/container/flat_hash_map.h b/third_party/absl/container/flat_hash_map.h new file mode 100644 index 0000000..aabed46 --- /dev/null +++ b/third_party/absl/container/flat_hash_map.h @@ -0,0 +1,29 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.! + +#ifndef ABSL_CONTAINER_FLAT_HASH_MAP_ +#define ABSL_CONTAINER_FLAT_HASH_MAP_ + +#include + +namespace absl { + +template , + typename Eq = std::equal_to, + typename Allocator = std::allocator>> +using flat_hash_map = std::unordered_map; + +} + +#endif // ABSL_CONTAINER_FLAT_HASH_MAP_ diff --git a/third_party/absl/container/flat_hash_set.h b/third_party/absl/container/flat_hash_set.h new file mode 100644 index 0000000..199f866 --- /dev/null +++ b/third_party/absl/container/flat_hash_set.h @@ -0,0 +1,29 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.! + +#ifndef ABSL_CONTAINER_FLAT_HASH_SET_ +#define ABSL_CONTAINER_FLAT_HASH_SET_ + +#include + +namespace absl { + +template , + typename Eq = std::equal_to, + typename Allocator = std::allocator> +using flat_hash_set = std::unordered_set; + +} + +#endif // ABSL_CONTAINER_FLAT_HASH_SET_ -- cgit v1.2.3