Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTaku Kudo <taku@google.com>2020-10-13 18:50:51 +0300
committerTaku Kudo <taku@google.com>2020-10-13 18:50:51 +0300
commit4d14b86ca03b5dbd04909f7dea60ab8a7e0c2600 (patch)
tree71af79ecbdd920c4d5b33cae660b2abafc5a86b2
parent8e70143bb3ba32c1bfac7d9ac8f3f0dfe51be61f (diff)
support big-endian architecture
-rw-r--r--src/normalizer.cc62
-rw-r--r--src/normalizer.h12
-rw-r--r--src/normalizer_test.cc16
-rw-r--r--src/sentencepiece_processor.cc6
-rw-r--r--src/util.h4
5 files changed, 47 insertions, 53 deletions
diff --git a/src/normalizer.cc b/src/normalizer.cc
index 8f10f12..7fe90a2 100644
--- a/src/normalizer.cc
+++ b/src/normalizer.cc
@@ -51,7 +51,12 @@ void Normalizer::Init() {
LOG(INFO) << "precompiled_charsmap is empty. use identity normalization.";
} else {
absl::string_view trie_blob, normalized;
+#ifdef __BIG_ENDIAN__
+ status_ = DecodePrecompiledCharsMap(index, &trie_blob, &normalized,
+ &precompiled_charsmap_buffer_);
+#else
status_ = DecodePrecompiledCharsMap(index, &trie_blob, &normalized);
+#endif
if (!status_.ok()) return;
// Reads the body of double array.
@@ -259,7 +264,11 @@ std::string Normalizer::EncodePrecompiledCharsMap(
blob.append(trie_blob.data(), trie_blob.size());
blob.append(normalized.data(), normalized.size());
- MaybeSwapEndian(&blob, trie_blob.size()).IgnoreError();
+#ifdef __BIG_ENDIAN__
+ uint32 *data = reinterpret_cast<uint32 *>(const_cast<char *>(blob.data()));
+ for (int i = 0; i <= trie_blob.size() / 4; ++i)
+ data[i] = util::Swap32(data[i]);
+#endif
return blob;
}
@@ -267,51 +276,36 @@ std::string Normalizer::EncodePrecompiledCharsMap(
// static
util::Status Normalizer::DecodePrecompiledCharsMap(
absl::string_view blob, absl::string_view *trie_blob,
- absl::string_view *normalized) {
+ absl::string_view *normalized, std::string *buffer) {
uint32 trie_blob_size = 0;
+
if (blob.size() <= sizeof(trie_blob_size) ||
!string_util::DecodePOD<uint32>(
absl::string_view(blob.data(), sizeof(trie_blob_size)),
- &trie_blob_size) ||
- trie_blob_size >= blob.size()) {
+ &trie_blob_size)) {
return util::InternalError("Blob for normalization rule is broken.");
}
- blob.remove_prefix(sizeof(trie_blob_size));
- *trie_blob = absl::string_view(blob.data(), trie_blob_size);
-
- blob.remove_prefix(trie_blob_size);
- *normalized = absl::string_view(blob.data(), blob.size());
-
- return util::OkStatus();
-}
-
-util::Status Normalizer::MaybeSwapEndian(std::string *precompiled_chars_map,
- uint32 trie_blob_size) {
#ifdef __BIG_ENDIAN__
- auto swap32 = [](uint32 x) -> uint32 { return __builtin_bswap32(x); };
+ trie_blob_size = util::Swap32(trie_blob_size);
+#endif
- auto blob = absl::string_view(precompiled_chars_map->data(),
- precompiled_chars_map->size());
+ if (trie_blob_size >= blob.size())
+ return util::InternalError("Trie data size exceeds the input blob size.");
- if (trie_blob_size == 0) {
- if (blob.size() <= sizeof(trie_blob_size) ||
- !string_util::DecodePOD<uint32>(
- absl::string_view(blob.data(), sizeof(trie_blob_size)),
- &trie_blob_size)) {
- return util::InternalError("Blob for normalization rule is broken.");
- }
- trie_blob_size = swap32(trie_blob_size);
- }
-
- if (trie_blob_size + 1 >= precompiled_chars_map->size())
- return util::InternalError("Blob for normalization rule is broken.");
+ blob.remove_prefix(sizeof(trie_blob_size));
- uint32 *data = reinterpret_cast<uint32 *>(
- const_cast<char *>(precompiled_chars_map->data()));
- for (int i = 0; i <= trie_blob_size; ++i) data[i] = swap32(data[i]);
+#ifdef __BIG_ENDIAN__
+ buffer->assign(blob.data(), trie_blob_size);
+ uint32 *data = reinterpret_cast<uint32 *>(const_cast<char *>(buffer->data()));
+ for (int i = 0; i < trie_blob_size / 4; ++i) data[i] = util::Swap32(data[i]);
+ *trie_blob = absl::string_view(buffer->data(), trie_blob_size);
+#else
+ *trie_blob = absl::string_view(blob.data(), trie_blob_size);
+#endif
-#endif // __BIG_ENDIAN__
+ blob.remove_prefix(trie_blob_size);
+ *normalized = absl::string_view(blob.data(), blob.size());
return util::OkStatus();
}
diff --git a/src/normalizer.h b/src/normalizer.h
index b198722..c31864d 100644
--- a/src/normalizer.h
+++ b/src/normalizer.h
@@ -95,10 +95,6 @@ class Normalizer {
friend class Builder;
- // Swap endian in `compiled_chars_map`. Only called big-endian machine.
- static util::Status MaybeSwapEndian(std::string *compiled_chars_map,
- uint32 trie_blob_size);
-
private:
FRIEND_TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest);
@@ -126,7 +122,8 @@ class Normalizer {
// Decodes blob into trie_blob and normalized string.
static util::Status DecodePrecompiledCharsMap(absl::string_view blob,
absl::string_view *trie_blob,
- absl::string_view *normalized);
+ absl::string_view *normalized,
+ std::string *buffer = nullptr);
// Maximum size of the return value of Trie, which corresponds
// to the maximum size of shared common prefix in the chars map.
@@ -149,6 +146,11 @@ class Normalizer {
// "_hello" and "_world".
const bool treat_whitespace_as_suffix_ = false;
+#ifndef __BIG_ENDIAN__
+ // Stores the blob for TRIE encoded in big-endian.
+ std::string precompiled_charsmap_buffer_;
+#endif
+
// Normalizer's status.
util::Status status_;
};
diff --git a/src/normalizer_test.cc b/src/normalizer_test.cc
index bc2c657..585e8f4 100644
--- a/src/normalizer_test.cc
+++ b/src/normalizer_test.cc
@@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
+#include "normalizer.h"
+
#include <vector>
#include "builder.h"
-#include "normalizer.h"
#include "sentencepiece_trainer.h"
#include "testharness.h"
#include "util.h"
@@ -358,16 +359,17 @@ TEST(NormalizerTest, NormalizeFullTest) {
TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest) {
const std::string blob = Normalizer::EncodePrecompiledCharsMap("foo", "bar");
+ std::string buf;
absl::string_view trie_blob, normalized_blob;
- EXPECT_TRUE(
- Normalizer::DecodePrecompiledCharsMap(blob, &trie_blob, &normalized_blob)
- .ok());
+ EXPECT_TRUE(Normalizer::DecodePrecompiledCharsMap(blob, &trie_blob,
+ &normalized_blob, &buf)
+ .ok());
EXPECT_EQ("foo", trie_blob);
EXPECT_EQ("bar", normalized_blob);
- EXPECT_FALSE(
- Normalizer::DecodePrecompiledCharsMap("", &trie_blob, &normalized_blob)
- .ok());
+ EXPECT_FALSE(Normalizer::DecodePrecompiledCharsMap("", &trie_blob,
+ &normalized_blob, &buf)
+ .ok());
}
TEST(NormalizerTest, StatusTest) {
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 751519f..765bc50 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -79,12 +79,6 @@ util::Status SentencePieceProcessor::Load(
model_proto_ = std::move(model_proto);
model_ = ModelFactory::Create(*model_proto_);
- if (!model_proto_->normalizer_spec().precompiled_charsmap().empty()) {
- RETURN_IF_ERROR(normalizer::Normalizer::MaybeSwapEndian(
- model_proto_->mutable_normalizer_spec()->mutable_precompiled_charsmap(),
- 0));
- }
-
normalizer_ = absl::make_unique<normalizer::Normalizer>(
model_proto_->normalizer_spec(), model_proto_->trainer_spec());
if (model_proto_->has_denormalizer_spec() &&
diff --git a/src/util.h b/src/util.h
index bf8a758..b390d4c 100644
--- a/src/util.h
+++ b/src/util.h
@@ -339,7 +339,7 @@ inline std::string JoinPath(absl::string_view path) {
}
template <typename... T>
-inline std::string JoinPath(absl::string_view first, const T &... rest) {
+inline std::string JoinPath(absl::string_view first, const T &...rest) {
#ifdef OS_WIN
return JoinPath(first) + "\\" + JoinPath(rest...);
#else
@@ -412,6 +412,8 @@ class StatusBuilder {
#define CHECK_GT_OR_RETURN(a, b) CHECK_OR_RETURN((a) > (b))
#define CHECK_LT_OR_RETURN(a, b) CHECK_OR_RETURN((a) < (b))
+inline uint32 Swap32(uint32 x) { return __builtin_bswap32(x); }
+
} // namespace util
namespace port {