Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTaku Kudo <taku@google.com>2020-05-20 07:45:49 +0300
committerTaku Kudo <taku@google.com>2020-05-20 07:45:49 +0300
commitd48247191a6d50e469ed1a4a36e877befffd1851 (patch)
tree83b5aba87746aaa7bd5dbafe26ed2628e0bbab74 /src/sentencepiece_processor.cc
parentb254e84528acdd1a2802d29922ae3496e8989be1 (diff)
0.1.91 pre-release
Diffstat (limited to 'src/sentencepiece_processor.cc')
-rw-r--r--src/sentencepiece_processor.cc20
1 files changed, 11 insertions, 9 deletions
diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc
index 4263a2f..a4dd575 100644
--- a/src/sentencepiece_processor.cc
+++ b/src/sentencepiece_processor.cc
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
-#include "sentencepiece_processor.h"
-
#include <map>
#include <set>
#include <utility>
@@ -24,6 +22,7 @@
#include "model_factory.h"
#include "model_interface.h"
#include "normalizer.h"
+#include "sentencepiece_processor.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/strings/numbers.h"
#include "third_party/absl/strings/str_cat.h"
@@ -446,6 +445,9 @@ util::Status SentencePieceProcessor::NBestEncode(
std::vector<size_t> norm_to_orig;
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
+ CHECK_OR_RETURN(model_->IsNBestEncodeAvailable())
+ << "NBestEncode is not available for the current model.";
+
const auto nbests = model_->NBestEncode(normalized, nbest_size);
CHECK_OR_RETURN(!nbests.empty()) << "NBestEncode returns empty result.";
@@ -470,7 +472,13 @@ util::Status SentencePieceProcessor::SampleEncode(
std::vector<size_t> norm_to_orig;
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
- if (nbest_size == 1 || nbest_size == 0) {
+ if (!model_->IsNBestEncodeAvailable() || nbest_size < 0) {
+ CHECK_OR_RETURN(model_->IsSampleEncodeAvailable())
+ << "SampleEncode is not available for the current model.";
+ const auto result = model_->SampleEncode(normalized, alpha);
+ RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig,
+ result, spt));
+ } else if (nbest_size == 1 || nbest_size == 0) {
const auto result = model_->Encode(normalized);
RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig,
result, spt));
@@ -487,11 +495,6 @@ util::Status SentencePieceProcessor::SampleEncode(
std::discrete_distribution<int> dist(probs.begin(), probs.end());
RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig,
nbests[dist(*mt)].first, spt));
-
- } else if (nbest_size < 0) {
- const auto result = model_->SampleEncode(normalized, alpha);
- RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig,
- result, spt));
}
return util::OkStatus();
@@ -828,6 +831,5 @@ util::Status SaveModelProto(absl::string_view filename,
return util::OkStatus();
}
-
} // namespace io
} // namespace sentencepiece