diff options
author | Taku Kudo <taku@google.com> | 2020-10-13 18:50:51 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2020-10-13 18:50:51 +0300 |
commit | 4d14b86ca03b5dbd04909f7dea60ab8a7e0c2600 (patch) | |
tree | 71af79ecbdd920c4d5b33cae660b2abafc5a86b2 | |
parent | 8e70143bb3ba32c1bfac7d9ac8f3f0dfe51be61f (diff) |
support big-endian architecture
-rw-r--r-- | src/normalizer.cc | 62 | ||||
-rw-r--r-- | src/normalizer.h | 12 | ||||
-rw-r--r-- | src/normalizer_test.cc | 16 | ||||
-rw-r--r-- | src/sentencepiece_processor.cc | 6 | ||||
-rw-r--r-- | src/util.h | 4 |
5 files changed, 47 insertions, 53 deletions
diff --git a/src/normalizer.cc b/src/normalizer.cc index 8f10f12..7fe90a2 100644 --- a/src/normalizer.cc +++ b/src/normalizer.cc @@ -51,7 +51,12 @@ void Normalizer::Init() { LOG(INFO) << "precompiled_charsmap is empty. use identity normalization."; } else { absl::string_view trie_blob, normalized; +#ifdef __BIG_ENDIAN__ + status_ = DecodePrecompiledCharsMap(index, &trie_blob, &normalized, + &precompiled_charsmap_buffer_); +#else status_ = DecodePrecompiledCharsMap(index, &trie_blob, &normalized); +#endif if (!status_.ok()) return; // Reads the body of double array. @@ -259,7 +264,11 @@ std::string Normalizer::EncodePrecompiledCharsMap( blob.append(trie_blob.data(), trie_blob.size()); blob.append(normalized.data(), normalized.size()); - MaybeSwapEndian(&blob, trie_blob.size()).IgnoreError(); +#ifdef __BIG_ENDIAN__ + uint32 *data = reinterpret_cast<uint32 *>(const_cast<char *>(blob.data())); + for (int i = 0; i <= trie_blob.size() / 4; ++i) + data[i] = util::Swap32(data[i]); +#endif return blob; } @@ -267,51 +276,36 @@ std::string Normalizer::EncodePrecompiledCharsMap( // static util::Status Normalizer::DecodePrecompiledCharsMap( absl::string_view blob, absl::string_view *trie_blob, - absl::string_view *normalized) { + absl::string_view *normalized, std::string *buffer) { uint32 trie_blob_size = 0; + if (blob.size() <= sizeof(trie_blob_size) || !string_util::DecodePOD<uint32>( absl::string_view(blob.data(), sizeof(trie_blob_size)), - &trie_blob_size) || - trie_blob_size >= blob.size()) { + &trie_blob_size)) { return util::InternalError("Blob for normalization rule is broken."); } - blob.remove_prefix(sizeof(trie_blob_size)); - *trie_blob = absl::string_view(blob.data(), trie_blob_size); - - blob.remove_prefix(trie_blob_size); - *normalized = absl::string_view(blob.data(), blob.size()); - - return util::OkStatus(); -} - -util::Status Normalizer::MaybeSwapEndian(std::string *precompiled_chars_map, - uint32 trie_blob_size) { #ifdef __BIG_ENDIAN__ - auto swap32 = [](uint32 x) -> uint32 { return __builtin_bswap32(x); }; + trie_blob_size = util::Swap32(trie_blob_size); +#endif - auto blob = absl::string_view(precompiled_chars_map->data(), - precompiled_chars_map->size()); + if (trie_blob_size >= blob.size()) + return util::InternalError("Trie data size exceeds the input blob size."); - if (trie_blob_size == 0) { - if (blob.size() <= sizeof(trie_blob_size) || - !string_util::DecodePOD<uint32>( - absl::string_view(blob.data(), sizeof(trie_blob_size)), - &trie_blob_size)) { - return util::InternalError("Blob for normalization rule is broken."); - } - trie_blob_size = swap32(trie_blob_size); - } - - if (trie_blob_size + 1 >= precompiled_chars_map->size()) - return util::InternalError("Blob for normalization rule is broken."); + blob.remove_prefix(sizeof(trie_blob_size)); - uint32 *data = reinterpret_cast<uint32 *>( - const_cast<char *>(precompiled_chars_map->data())); - for (int i = 0; i <= trie_blob_size; ++i) data[i] = swap32(data[i]); +#ifdef __BIG_ENDIAN__ + buffer->assign(blob.data(), trie_blob_size); + uint32 *data = reinterpret_cast<uint32 *>(const_cast<char *>(buffer->data())); + for (int i = 0; i < trie_blob_size / 4; ++i) data[i] = util::Swap32(data[i]); + *trie_blob = absl::string_view(buffer->data(), trie_blob_size); +#else + *trie_blob = absl::string_view(blob.data(), trie_blob_size); +#endif -#endif // __BIG_ENDIAN__ + blob.remove_prefix(trie_blob_size); + *normalized = absl::string_view(blob.data(), blob.size()); return util::OkStatus(); } diff --git a/src/normalizer.h b/src/normalizer.h index b198722..c31864d 100644 --- a/src/normalizer.h +++ b/src/normalizer.h @@ -95,10 +95,6 @@ class Normalizer { friend class Builder; - // Swap endian in `compiled_chars_map`. Only called big-endian machine. - static util::Status MaybeSwapEndian(std::string *compiled_chars_map, - uint32 trie_blob_size); - private: FRIEND_TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest); @@ -126,7 +122,8 @@ class Normalizer { // Decodes blob into trie_blob and normalized string. static util::Status DecodePrecompiledCharsMap(absl::string_view blob, absl::string_view *trie_blob, - absl::string_view *normalized); + absl::string_view *normalized, + std::string *buffer = nullptr); // Maximum size of the return value of Trie, which corresponds // to the maximum size of shared common prefix in the chars map. @@ -149,6 +146,11 @@ class Normalizer { // "_hello" and "_world". const bool treat_whitespace_as_suffix_ = false; +#ifndef __BIG_ENDIAN__ + // Stores the blob for TRIE encoded in big-endian. + std::string precompiled_charsmap_buffer_; +#endif + // Normalizer's status. util::Status status_; }; diff --git a/src/normalizer_test.cc b/src/normalizer_test.cc index bc2c657..585e8f4 100644 --- a/src/normalizer_test.cc +++ b/src/normalizer_test.cc @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License.! +#include "normalizer.h" + #include <vector> #include "builder.h" -#include "normalizer.h" #include "sentencepiece_trainer.h" #include "testharness.h" #include "util.h" @@ -358,16 +359,17 @@ TEST(NormalizerTest, NormalizeFullTest) { TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest) { const std::string blob = Normalizer::EncodePrecompiledCharsMap("foo", "bar"); + std::string buf; absl::string_view trie_blob, normalized_blob; - EXPECT_TRUE( - Normalizer::DecodePrecompiledCharsMap(blob, &trie_blob, &normalized_blob) - .ok()); + EXPECT_TRUE(Normalizer::DecodePrecompiledCharsMap(blob, &trie_blob, + &normalized_blob, &buf) + .ok()); EXPECT_EQ("foo", trie_blob); EXPECT_EQ("bar", normalized_blob); - EXPECT_FALSE( - Normalizer::DecodePrecompiledCharsMap("", &trie_blob, &normalized_blob) - .ok()); + EXPECT_FALSE(Normalizer::DecodePrecompiledCharsMap("", &trie_blob, + &normalized_blob, &buf) + .ok()); } TEST(NormalizerTest, StatusTest) { diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 751519f..765bc50 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -79,12 +79,6 @@ util::Status SentencePieceProcessor::Load( model_proto_ = std::move(model_proto); model_ = ModelFactory::Create(*model_proto_); - if (!model_proto_->normalizer_spec().precompiled_charsmap().empty()) { - RETURN_IF_ERROR(normalizer::Normalizer::MaybeSwapEndian( - model_proto_->mutable_normalizer_spec()->mutable_precompiled_charsmap(), - 0)); - } - normalizer_ = absl::make_unique<normalizer::Normalizer>( model_proto_->normalizer_spec(), model_proto_->trainer_spec()); if (model_proto_->has_denormalizer_spec() && @@ -339,7 +339,7 @@ inline std::string JoinPath(absl::string_view path) { } template <typename... T> -inline std::string JoinPath(absl::string_view first, const T &... rest) { +inline std::string JoinPath(absl::string_view first, const T &...rest) { #ifdef OS_WIN return JoinPath(first) + "\\" + JoinPath(rest...); #else @@ -412,6 +412,8 @@ class StatusBuilder { #define CHECK_GT_OR_RETURN(a, b) CHECK_OR_RETURN((a) > (b)) #define CHECK_LT_OR_RETURN(a, b) CHECK_OR_RETURN((a) < (b)) +inline uint32 Swap32(uint32 x) { return __builtin_bswap32(x); } + } // namespace util namespace port { |