Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Taku Kudo <taku@google.com> 2020-10-13 10:44:42 +0300
committer Taku Kudo <taku@google.com> 2020-10-13 10:44:42 +0300
commit 8e70143bb3ba32c1bfac7d9ac8f3f0dfe51be61f (patch)
tree ea17ea9f78fd66855281e9b75efcf1d8c311c7da
parent f7bc3dbfb6b9afb4b7323e01b200b83291ee9b34 (diff)
support big-endian architecture
-rw-r--r-- src/normalizer.cc 36
-rw-r--r-- src/normalizer.h 4
-rw-r--r-- src/sentencepiece_processor.cc 10
-rw-r--r-- src/spm_train_main.cc 34
4 files changed, 82 insertions, 2 deletions
diff --git a/src/normalizer.cc b/src/normalizer.cc
index 7e342b5..8f10f12 100644
--- a/src/normalizer.cc
+++ b/src/normalizer.cc
@@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
+#include "normalizer.h"
+
#include <utility>
#include <vector>
#include "common.h"
-#include "normalizer.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/strings/match.h"
#include "third_party/absl/strings/string_view.h"
@@ -257,6 +258,9 @@ std::string Normalizer::EncodePrecompiledCharsMap(
blob.append(string_util::EncodePOD<uint32>(trie_blob.size()));
blob.append(trie_blob.data(), trie_blob.size());
blob.append(normalized.data(), normalized.size());
+
+ MaybeSwapEndian(&blob, trie_blob.size()).IgnoreError();
+
return blob;
}
@@ -282,6 +286,36 @@ util::Status Normalizer::DecodePrecompiledCharsMap(
return util::OkStatus();
}
+// Byte-swaps, in place, the 32-bit size header and the trie words that follow
+// it in `precompiled_chars_map`. The swap is compiled in only when
+// __BIG_ENDIAN__ is defined; on every other build this is a no-op that
+// returns OK.
+//
+// The blob is laid out as
+//   [uint32 trie size in bytes][trie bytes][normalized bytes]
+// (see EncodePrecompiledCharsMap). Only the header and the trie are treated
+// as 32-bit words; the trailing normalized bytes are left untouched.
+//
+// `trie_blob_size` is the trie size in bytes. Pass 0 to have it read from the
+// (still unswapped) header of the blob.
+//
+// Returns InternalError when the blob is too short to hold the header plus
+// `trie_blob_size` bytes of trie data.
+util::Status Normalizer::MaybeSwapEndian(std::string *precompiled_chars_map,
+                                         uint32 trie_blob_size) {
+#ifdef __BIG_ENDIAN__
+  auto swap32 = [](uint32 x) -> uint32 { return __builtin_bswap32(x); };
+  constexpr size_t kWordSize = sizeof(uint32);
+
+  const size_t blob_size = precompiled_chars_map->size();
+
+  if (trie_blob_size == 0) {
+    if (blob_size <= kWordSize ||
+        !string_util::DecodePOD<uint32>(
+            absl::string_view(precompiled_chars_map->data(), kWordSize),
+            &trie_blob_size)) {
+      return util::InternalError("Blob for normalization rule is broken.");
+    }
+    // The header has not been swapped yet, so the decoded value is in the
+    // opposite byte order; swap it to obtain the usable size.
+    trie_blob_size = swap32(trie_blob_size);
+  }
+
+  // The header and the whole trie must fit inside the blob. Written as a
+  // subtraction so that a huge `trie_blob_size` cannot wrap around.
+  if (blob_size < kWordSize || trie_blob_size > blob_size - kWordSize)
+    return util::InternalError("Blob for normalization rule is broken.");
+
+  // `trie_blob_size` counts bytes, not words: swap one header word plus
+  // trie_blob_size / 4 trie words. Using the byte count as the word count
+  // would run past the trie into the normalized bytes and beyond the buffer.
+  const size_t num_words = 1 + trie_blob_size / kWordSize;
+  uint32 *data = reinterpret_cast<uint32 *>(
+      const_cast<char *>(precompiled_chars_map->data()));
+  for (size_t i = 0; i < num_words; ++i) data[i] = swap32(data[i]);
+
+#endif  // __BIG_ENDIAN__
+
+  return util::OkStatus();
+}
+
PrefixMatcher::PrefixMatcher(const std::set<absl::string_view> &dic) {
if (dic.empty()) return;
std::vector<const char *> key;
diff --git a/src/normalizer.h b/src/normalizer.h
index ab12fac..b198722 100644
--- a/src/normalizer.h
+++ b/src/normalizer.h
@@ -95,6 +95,10 @@ class Normalizer {
friend class Builder;
+ // Byte-swaps the leading 32-bit words (the size header and the trie data)
+ // of `compiled_chars_map` in place. The swap is compiled in only when
+ // __BIG_ENDIAN__ is defined; otherwise this returns OK without touching
+ // the blob. Pass `trie_blob_size` == 0 to have the trie size read from the
+ // blob's own header.
+ static util::Status MaybeSwapEndian(std::string *compiled_chars_map,
+ uint32 trie_blob_size);
+
private:
FRIEND_TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest);
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index df053fd..751519f 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
+#include "sentencepiece_processor.h"
+
#include <map>
#include <set>
#include <utility>
@@ -22,7 +24,6 @@
#include "model_interface.h"
#include "normalizer.h"
#include "sentencepiece.pb.h"
-#include "sentencepiece_processor.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/strings/numbers.h"
#include "third_party/absl/strings/str_cat.h"
@@ -77,6 +78,13 @@ util::Status SentencePieceProcessor::Load(
std::unique_ptr<ModelProto> model_proto) {
model_proto_ = std::move(model_proto);
model_ = ModelFactory::Create(*model_proto_);
+
+ if (!model_proto_->normalizer_spec().precompiled_charsmap().empty()) {
+ RETURN_IF_ERROR(normalizer::Normalizer::MaybeSwapEndian(
+ model_proto_->mutable_normalizer_spec()->mutable_precompiled_charsmap(),
+ 0));
+ }
+
normalizer_ = absl::make_unique<normalizer::Normalizer>(
model_proto_->normalizer_spec(), model_proto_->trainer_spec());
if (model_proto_->has_denormalizer_spec() &&
diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc
index 8a0912b..847b7e7 100644
--- a/src/spm_train_main.cc
+++ b/src/spm_train_main.cc
@@ -14,11 +14,13 @@
#include <map>
+#include "filesystem.h"
#include "init.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_trainer.h"
#include "third_party/absl/flags/flag.h"
#include "third_party/absl/strings/ascii.h"
+#include "third_party/absl/strings/str_join.h"
#include "third_party/absl/strings/str_split.h"
#include "util.h"
@@ -79,11 +81,17 @@ ABSL_FLAG(bool, treat_whitespace_as_suffix,
"treat whitespace marker as suffix instead of prefix.");
ABSL_FLAG(std::string, control_symbols, "",
"comma separated list of control symbols");
+ABSL_FLAG(std::string, control_symbols_file, "",
+ "load control_symbols from file.");
ABSL_FLAG(std::string, user_defined_symbols, "",
"comma separated list of user defined symbols");
+ABSL_FLAG(std::string, user_defined_symbols_file, "",
+ "load user_defined_symbols from file.");
ABSL_FLAG(std::string, required_chars, "",
"UTF8 characters in this flag are always used in the character "
"set regardless of --character_coverage");
+ABSL_FLAG(std::string, required_chars_file, "",
+ "load required_chars from file.");
ABSL_FLAG(bool, byte_fallback, kDefaultTrainerSpec.byte_fallback(),
"decompose unknown pieces into UTF-8 byte pieces");
ABSL_FLAG(bool, vocabulary_output_piece_score,
@@ -140,6 +148,15 @@ int main(int argc, char *argv[]) {
CHECK(!absl::GetFlag(FLAGS_input).empty());
CHECK(!absl::GetFlag(FLAGS_model_prefix).empty());
+  // Opens `filename`, verifies the reader's status with CHECK_OK, and returns
+  // every line ReadLine() yields, in file order.
+  auto load_lines = [](absl::string_view filename) {
+    auto reader = sentencepiece::filesystem::NewReadableFile(filename);
+    CHECK_OK(reader->status());
+    std::vector<std::string> result;
+    for (std::string buf; reader->ReadLine(&buf);) result.push_back(buf);
+    return result;
+  };
+
// Populates the value from flags to spec.
#define SetTrainerSpecFromFlag(name) \
trainer_spec.set_##name(absl::GetFlag(FLAGS_##name));
@@ -147,6 +164,12 @@ int main(int argc, char *argv[]) {
#define SetNormalizerSpecFromFlag(name) \
normalizer_spec.set_##name(absl::GetFlag(FLAGS_##name));
+// When --<name>_file is non-empty, reads that file with `load_lines` and
+// stores all of its lines, joined with no separator, into
+// trainer_spec.<name>. Requires `load_lines` and `trainer_spec` to be in
+// scope at the expansion site.
+// NOTE(review): this is expanded after SetTrainerSpecFromFlag(<name>), so the
+// file contents presumably replace the value taken from --<name>; confirm.
+#define SetTrainerSpecFromFile(name) \
+ if (!absl::GetFlag(FLAGS_##name##_file).empty()) { \
+ const auto lines = load_lines(absl::GetFlag(FLAGS_##name##_file)); \
+ trainer_spec.set_##name(absl::StrJoin(lines, "")); \
+ }
+
#define SetRepeatedTrainerSpecFromFlag(name) \
if (!absl::GetFlag(FLAGS_##name).empty()) { \
for (const auto &v : \
@@ -155,6 +178,13 @@ int main(int argc, char *argv[]) {
} \
}
+// When --<name>_file is non-empty, reads that file with `load_lines` and
+// calls trainer_spec.add_<name>() once per line, in file order. Requires
+// `load_lines` and `trainer_spec` to be in scope at the expansion site.
+// NOTE(review): entries are added on top of whatever
+// SetRepeatedTrainerSpecFromFlag(<name>) already added; presumably both
+// sources are meant to be combined — confirm.
+#define SetRepeatedTrainerSpecFromFile(name) \
+ if (!absl::GetFlag(FLAGS_##name##_file).empty()) { \
+ for (const auto &v : load_lines(absl::GetFlag(FLAGS_##name##_file))) { \
+ trainer_spec.add_##name(v); \
+ } \
+ }
+
SetRepeatedTrainerSpecFromFlag(input);
SetTrainerSpecFromFlag(input_format);
@@ -188,12 +218,16 @@ int main(int argc, char *argv[]) {
SetTrainerSpecFromFlag(pad_piece);
SetTrainerSpecFromFlag(unk_surface);
SetTrainerSpecFromFlag(required_chars);
+ SetTrainerSpecFromFile(required_chars);
SetTrainerSpecFromFlag(vocabulary_output_piece_score);
SetRepeatedTrainerSpecFromFlag(accept_language);
SetRepeatedTrainerSpecFromFlag(control_symbols);
SetRepeatedTrainerSpecFromFlag(user_defined_symbols);
SetTrainerSpecFromFlag(train_extremely_large_corpus);
+ SetRepeatedTrainerSpecFromFile(control_symbols);
+ SetRepeatedTrainerSpecFromFile(user_defined_symbols);
+
normalizer_spec.set_name(absl::GetFlag(FLAGS_normalization_rule_name));
SetNormalizerSpecFromFlag(normalization_rule_tsv);
SetNormalizerSpecFromFlag(add_dummy_prefix);