diff options
author | Taku Kudo <taku@google.com> | 2020-10-13 10:44:42 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2020-10-13 10:44:42 +0300 |
commit | 8e70143bb3ba32c1bfac7d9ac8f3f0dfe51be61f (patch) | |
tree | ea17ea9f78fd66855281e9b75efcf1d8c311c7da | |
parent | f7bc3dbfb6b9afb4b7323e01b200b83291ee9b34 (diff) |
support big-endian architecture
-rw-r--r-- | src/normalizer.cc | 36 | ||||
-rw-r--r-- | src/normalizer.h | 4 | ||||
-rw-r--r-- | src/sentencepiece_processor.cc | 10 | ||||
-rw-r--r-- | src/spm_train_main.cc | 34 |
4 files changed, 82 insertions, 2 deletions
diff --git a/src/normalizer.cc b/src/normalizer.cc index 7e342b5..8f10f12 100644 --- a/src/normalizer.cc +++ b/src/normalizer.cc @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License.! +#include "normalizer.h" + #include <utility> #include <vector> #include "common.h" -#include "normalizer.h" #include "third_party/absl/memory/memory.h" #include "third_party/absl/strings/match.h" #include "third_party/absl/strings/string_view.h" @@ -257,6 +258,9 @@ std::string Normalizer::EncodePrecompiledCharsMap( blob.append(string_util::EncodePOD<uint32>(trie_blob.size())); blob.append(trie_blob.data(), trie_blob.size()); blob.append(normalized.data(), normalized.size()); + + MaybeSwapEndian(&blob, trie_blob.size()).IgnoreError(); + return blob; } @@ -282,6 +286,36 @@ util::Status Normalizer::DecodePrecompiledCharsMap( return util::OkStatus(); } +util::Status Normalizer::MaybeSwapEndian(std::string *precompiled_chars_map, + uint32 trie_blob_size) { +#ifdef __BIG_ENDIAN__ + auto swap32 = [](uint32 x) -> uint32 { return __builtin_bswap32(x); }; + + auto blob = absl::string_view(precompiled_chars_map->data(), + precompiled_chars_map->size()); + + if (trie_blob_size == 0) { + if (blob.size() <= sizeof(trie_blob_size) || + !string_util::DecodePOD<uint32>( + absl::string_view(blob.data(), sizeof(trie_blob_size)), + &trie_blob_size)) { + return util::InternalError("Blob for normalization rule is broken."); + } + trie_blob_size = swap32(trie_blob_size); + } + + if (trie_blob_size + 1 >= precompiled_chars_map->size()) + return util::InternalError("Blob for normalization rule is broken."); + + uint32 *data = reinterpret_cast<uint32 *>( + const_cast<char *>(precompiled_chars_map->data())); + for (int i = 0; i <= trie_blob_size; ++i) data[i] = swap32(data[i]); + +#endif // __BIG_ENDIAN__ + + return util::OkStatus(); +} + PrefixMatcher::PrefixMatcher(const std::set<absl::string_view> &dic) { if (dic.empty()) return; std::vector<const char *> key; diff --git a/src/normalizer.h b/src/normalizer.h index ab12fac..b198722 100644 --- a/src/normalizer.h +++ b/src/normalizer.h @@ -95,6 +95,10 @@ class Normalizer { friend class Builder; + // Swap endian in `compiled_chars_map`. Only called big-endian machine. + static util::Status MaybeSwapEndian(std::string *compiled_chars_map, + uint32 trie_blob_size); + private: FRIEND_TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest); diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index df053fd..751519f 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License.! +#include "sentencepiece_processor.h" + #include <map> #include <set> #include <utility> @@ -22,7 +24,6 @@ #include "model_interface.h" #include "normalizer.h" #include "sentencepiece.pb.h" -#include "sentencepiece_processor.h" #include "third_party/absl/memory/memory.h" #include "third_party/absl/strings/numbers.h" #include "third_party/absl/strings/str_cat.h" @@ -77,6 +78,13 @@ util::Status SentencePieceProcessor::Load( std::unique_ptr<ModelProto> model_proto) { model_proto_ = std::move(model_proto); model_ = ModelFactory::Create(*model_proto_); + + if (!model_proto_->normalizer_spec().precompiled_charsmap().empty()) { + RETURN_IF_ERROR(normalizer::Normalizer::MaybeSwapEndian( + model_proto_->mutable_normalizer_spec()->mutable_precompiled_charsmap(), + 0)); + } + normalizer_ = absl::make_unique<normalizer::Normalizer>( model_proto_->normalizer_spec(), model_proto_->trainer_spec()); if (model_proto_->has_denormalizer_spec() && diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc index 8a0912b..847b7e7 100644 --- a/src/spm_train_main.cc +++ b/src/spm_train_main.cc @@ -14,11 +14,13 @@ #include <map> +#include "filesystem.h" #include "init.h" #include "sentencepiece_model.pb.h" #include "sentencepiece_trainer.h" #include "third_party/absl/flags/flag.h" #include "third_party/absl/strings/ascii.h" +#include "third_party/absl/strings/str_join.h" #include "third_party/absl/strings/str_split.h" #include "util.h" @@ -79,11 +81,17 @@ ABSL_FLAG(bool, treat_whitespace_as_suffix, "treat whitespace marker as suffix instead of prefix."); ABSL_FLAG(std::string, control_symbols, "", "comma separated list of control symbols"); +ABSL_FLAG(std::string, control_symbols_file, "", + "load control_symbols from file."); ABSL_FLAG(std::string, user_defined_symbols, "", "comma separated list of user defined symbols"); +ABSL_FLAG(std::string, user_defined_symbols_file, "", + "load user_defined_symbols from file."); ABSL_FLAG(std::string, required_chars, "", "UTF8 characters in this flag are always used in the character " "set regardless of --character_coverage"); +ABSL_FLAG(std::string, required_chars_file, "", + "load required_chars from file."); ABSL_FLAG(bool, byte_fallback, kDefaultTrainerSpec.byte_fallback(), "decompose unknown pieces into UTF-8 byte pieces"); ABSL_FLAG(bool, vocabulary_output_piece_score, @@ -140,6 +148,15 @@ int main(int argc, char *argv[]) { CHECK(!absl::GetFlag(FLAGS_input).empty()); CHECK(!absl::GetFlag(FLAGS_model_prefix).empty()); + auto load_lines = [](absl::string_view filename) { + std::vector<std::string> lines; + auto input = sentencepiece::filesystem::NewReadableFile(filename); + CHECK_OK(input->status()); + std::string line; + while (input->ReadLine(&line)) lines.emplace_back(line); + return lines; + }; + // Populates the value from flags to spec. #define SetTrainerSpecFromFlag(name) \ trainer_spec.set_##name(absl::GetFlag(FLAGS_##name)); @@ -147,6 +164,12 @@ int main(int argc, char *argv[]) { #define SetNormalizerSpecFromFlag(name) \ normalizer_spec.set_##name(absl::GetFlag(FLAGS_##name)); +#define SetTrainerSpecFromFile(name) \ + if (!absl::GetFlag(FLAGS_##name##_file).empty()) { \ + const auto lines = load_lines(absl::GetFlag(FLAGS_##name##_file)); \ + trainer_spec.set_##name(absl::StrJoin(lines, "")); \ + } + #define SetRepeatedTrainerSpecFromFlag(name) \ if (!absl::GetFlag(FLAGS_##name).empty()) { \ for (const auto &v : \ @@ -155,6 +178,13 @@ int main(int argc, char *argv[]) { } \ } +#define SetRepeatedTrainerSpecFromFile(name) \ + if (!absl::GetFlag(FLAGS_##name##_file).empty()) { \ + for (const auto &v : load_lines(absl::GetFlag(FLAGS_##name##_file))) { \ + trainer_spec.add_##name(v); \ + } \ + } + SetRepeatedTrainerSpecFromFlag(input); SetTrainerSpecFromFlag(input_format); @@ -188,12 +218,16 @@ int main(int argc, char *argv[]) { SetTrainerSpecFromFlag(pad_piece); SetTrainerSpecFromFlag(unk_surface); SetTrainerSpecFromFlag(required_chars); + SetTrainerSpecFromFile(required_chars); SetTrainerSpecFromFlag(vocabulary_output_piece_score); SetRepeatedTrainerSpecFromFlag(accept_language); SetRepeatedTrainerSpecFromFlag(control_symbols); SetRepeatedTrainerSpecFromFlag(user_defined_symbols); SetTrainerSpecFromFlag(train_extremely_large_corpus); + SetRepeatedTrainerSpecFromFile(control_symbols); + SetRepeatedTrainerSpecFromFile(user_defined_symbols); + normalizer_spec.set_name(absl::GetFlag(FLAGS_normalization_rule_name)); SetNormalizerSpecFromFlag(normalization_rule_tsv); SetNormalizerSpecFromFlag(add_dummy_prefix); |