Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTaku Kudo <taku@google.com>2018-06-17 10:39:40 +0300
committerTaku Kudo <taku@google.com>2018-06-17 10:39:40 +0300
commit74a7e18077ce32dae9033124514975da2d54da80 (patch)
tree95d4de4cafd2cfc2119b07c2ccad87bd5439b814 /src
parent0a30c9bac16369b4463dbbace7163816416e2094 (diff)
Minor fixes
Diffstat (limited to 'src')
-rw-r--r--src/sentencepiece_processor.cc26
-rw-r--r--src/sentencepiece_processor_test.cc2
2 files changed, 11 insertions, 17 deletions
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 88c3266..ca7d548 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -299,15 +299,13 @@ util::Status SentencePieceProcessor::PopulateSentencePieceText(
} else {
const size_t begin = consumed;
const size_t end = consumed + w.size();
- if (begin >= norm_to_orig.size() || end >= norm_to_orig.size()) {
- return util::OutOfRangeError("consumed index is out-of-range.");
- }
+ CHECK_LT_OR_RETURN(begin, norm_to_orig.size());
+ CHECK_LT_OR_RETURN(end, norm_to_orig.size());
const size_t orig_begin = norm_to_orig[begin];
const size_t orig_end = norm_to_orig[end];
- if (orig_begin > input.size() || orig_end > input.size() ||
- orig_begin > orig_end) {
- return util::OutOfRangeError("original index is out-of-range.");
- }
+ CHECK_LE_OR_RETURN(orig_begin, input.size());
+ CHECK_LE_OR_RETURN(orig_end, input.size());
+ CHECK_LE_OR_RETURN(orig_begin, orig_end);
const auto surface = input.substr(orig_begin, orig_end - orig_begin);
// Merges continuous run of unknown pieces so that decoder
// can copy or generate unknown tokens easily.
@@ -331,16 +329,15 @@ util::Status SentencePieceProcessor::PopulateSentencePieceText(
is_prev_unk = is_unk;
}
- if (consumed != normalized.size()) {
- return util::OutOfRangeError("all normalized characters are not consumed.");
- }
+ CHECK_EQ_OR_RETURN(consumed, normalized.size())
+ << "all normalized characters are not consumed.";
RETURN_IF_ERROR(ApplyExtraOptions(encode_extra_options_, spt));
spt->set_text(input);
return util::OkStatus();
-}
+} // namespace sentencepiece
util::Status SentencePieceProcessor::Encode(const std::string &input,
SentencePieceText *spt) const {
@@ -384,16 +381,13 @@ util::Status SentencePieceProcessor::SampleEncode(
SentencePieceText *spt) const {
CHECK_OR_RETURN_STATUS_PROTO(spt);
- if (nbest_size > 512 || nbest_size == 0) {
- return util::OutOfRangeError(
- "nbest_size must be 0 < nbest_size <= 512 or nbest_size < 0.");
- }
+ CHECK_LE_OR_RETURN(nbest_size, 512) << "nbest_size must be nbest_size <= 512";
std::string normalized;
std::vector<size_t> norm_to_orig;
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
- if (nbest_size == 1) {
+ if (nbest_size == 1 || nbest_size == 0) {
const auto result = model_->Encode(normalized);
RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig,
result, spt));
diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc
index 51a83db..7d77a33 100644
--- a/src/sentencepiece_processor_test.cc
+++ b/src/sentencepiece_processor_test.cc
@@ -408,7 +408,7 @@ TEST(SentencepieceProcessorTest, SampleEncodeTest) {
}
EXPECT_NOT_OK(sp.SampleEncode("ABC DEF", 1024, 0.5, &output));
- EXPECT_NOT_OK(sp.SampleEncode("ABC DEF", 0, 0.5, &output));
+ EXPECT_OK(sp.SampleEncode("ABC DEF", 0, 0.5, &output));
EXPECT_OK(sp.SampleEncode("ABC DEF", 1, 0.5, &output));
std::vector<int> freq(2, 0);