diff options
author | Taku Kudo <taku@google.com> | 2020-05-20 07:45:49 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2020-05-20 07:45:49 +0300 |
commit | d48247191a6d50e469ed1a4a36e877befffd1851 (patch) | |
tree | 83b5aba87746aaa7bd5dbafe26ed2628e0bbab74 /src/sentencepiece_processor.cc | |
parent | b254e84528acdd1a2802d29922ae3496e8989be1 (diff) |
0.1.91 pre-release
Diffstat (limited to 'src/sentencepiece_processor.cc')
-rw-r--r-- | src/sentencepiece_processor.cc | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 4263a2f..a4dd575 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "sentencepiece_processor.h" - #include <map> #include <set> #include <utility> @@ -24,6 +22,7 @@ #include "model_factory.h" #include "model_interface.h" #include "normalizer.h" +#include "sentencepiece_processor.h" #include "third_party/absl/memory/memory.h" #include "third_party/absl/strings/numbers.h" #include "third_party/absl/strings/str_cat.h" @@ -446,6 +445,9 @@ util::Status SentencePieceProcessor::NBestEncode( std::vector<size_t> norm_to_orig; RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig)); + CHECK_OR_RETURN(model_->IsNBestEncodeAvailable()) + << "NBestEncode is not available for the current model."; + const auto nbests = model_->NBestEncode(normalized, nbest_size); CHECK_OR_RETURN(!nbests.empty()) << "NBestEncode returns empty result."; @@ -470,7 +472,13 @@ util::Status SentencePieceProcessor::SampleEncode( std::vector<size_t> norm_to_orig; RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig)); - if (nbest_size == 1 || nbest_size == 0) { + if (!model_->IsNBestEncodeAvailable() || nbest_size < 0) { + CHECK_OR_RETURN(model_->IsSampleEncodeAvailable()) + << "SampleEncode is not available for the current model."; + const auto result = model_->SampleEncode(normalized, alpha); + RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig, + result, spt)); + } else if (nbest_size == 1 || nbest_size == 0) { const auto result = model_->Encode(normalized); RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig, result, spt)); @@ -487,11 +495,6 @@ util::Status SentencePieceProcessor::SampleEncode( std::discrete_distribution<int> dist(probs.begin(), probs.end()); RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig, nbests[dist(*mt)].first, spt)); - - } else if (nbest_size < 0) { - const auto result = model_->SampleEncode(normalized, alpha); - RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig, - result, spt)); } return util::OkStatus(); @@ -828,6 +831,5 @@ util::Status SaveModelProto(absl::string_view filename, return util::OkStatus(); } - } // namespace io } // namespace sentencepiece |