Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/sentencepiece/__init__.py')
-rw-r--r-- python/src/sentencepiece/__init__.py 38
1 files changed, 22 insertions, 16 deletions
diff --git a/python/src/sentencepiece/__init__.py b/python/src/sentencepiece/__init__.py
index e704a2a..001ffc7 100644
--- a/python/src/sentencepiece/__init__.py
+++ b/python/src/sentencepiece/__init__.py
@@ -116,9 +116,6 @@ class SentencePieceProcessor(object):
def DecodePieces(self, pieces):
return _sentencepiece.SentencePieceProcessor_DecodePieces(self, pieces)
- def DecodeIds(self, ids):
- return _sentencepiece.SentencePieceProcessor_DecodeIds(self, ids)
-
def EncodeAsSerializedProto(self, input):
return _sentencepiece.SentencePieceProcessor_EncodeAsSerializedProto(self, input)
@@ -131,9 +128,6 @@ class SentencePieceProcessor(object):
def DecodePiecesAsSerializedProto(self, pieces):
return _sentencepiece.SentencePieceProcessor_DecodePiecesAsSerializedProto(self, pieces)
- def DecodeIdsAsSerializedProto(self, ids):
- return _sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProto(self, ids)
-
def GetPieceSize(self):
return _sentencepiece.SentencePieceProcessor_GetPieceSize(self)
@@ -176,6 +170,12 @@ class SentencePieceProcessor(object):
def LoadFromFile(self, arg):
return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg)
+ def DecodeIdsWithCheck(self, ids):
+ return _sentencepiece.SentencePieceProcessor_DecodeIdsWithCheck(self, ids)
+
+ def DecodeIdsAsSerializedProtoWithCheck(self, ids):
+ return _sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck(self, ids)
+
def Init(self,
model_file=None,
model_proto=None,
@@ -242,8 +242,8 @@ class SentencePieceProcessor(object):
nbest_size < 0: assuming that nbest_size is infinite and samples
from the all hypothesis (lattice) using
forward-filtering-and-backward-sampling algorithm.
- alpha: Soothing parameter for unigram sampling, and dropout probability of
- merge operations for BPE-dropout.
+ alpha: Smoothing parameter for unigram sampling, and merge probability for
+ BPE-dropout (probability 'p' in BPE-dropout paper).
"""
if out_type is None:
@@ -262,12 +262,12 @@ class SentencePieceProcessor(object):
alpha = self._alpha
if enable_sampling == True and (nbest_size is None or nbest_size == 0 or
- nbest_size == 1 or alpha is None or
- alpha <= 0.0 or alpha > 1.0):
+ nbest_size == 1 or alpha is None):
raise RuntimeError(
'When enable_sampling is True, We must specify "nbest_size > 1" or "nbest_size = -1", '
- 'and "0.0 < alpha < 1.0". "nbest_size = -1" is enabled only on unigram mode and '
- 'samples from all candidates on the lattice instead of nbest segmentations. '
+ 'and "alpha". "nbest_size" is enabled only on unigram mode ignored in BPE-dropout. '
+ 'when "nbest_size = -1" , this method samples from all candidates on the lattice '
+ 'instead of nbest segmentations.'
)
def _encode(text):
@@ -310,7 +310,7 @@ class SentencePieceProcessor(object):
if not input:
return self.DecodeIds([])
elif type(input) is int:
- return self.DecodeIds([input])
+ return self.DecodeIdsWithCheck([input])
elif type(input) is str:
return self.DecodePieces([input])
@@ -318,7 +318,7 @@ class SentencePieceProcessor(object):
if not input:
return self.DecodeIds([])
if type(input[0]) is int:
- return self.DecodeIds(input)
+ return self.DecodeIdsWithCheck(input)
return self.DecodePieces(input)
if type(input[0]) is list:
@@ -486,12 +486,16 @@ def _add_snake_case(classname):
def _batchnize(classname, name):
"""Enables batch request for the method classname.name."""
func = getattr(classname, name, None)
+ def _func(v, n):
+ if type(n) is int and (n < 0 or n >= v.piece_size()):
+ raise IndexError('piece id is out of range.')
+ return func(v, n)
def _batched_func(self, arg):
if type(arg) is list:
- return [func(self, n) for n in arg]
+ return [_func(self, n) for n in arg]
else:
- return func(self, arg)
+ return _func(self, arg)
setattr(classname, name, _batched_func)
@@ -501,6 +505,8 @@ setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)
SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
+SentencePieceProcessor.DecodeIds = SentencePieceProcessor.DecodeIdsWithCheck
+SentencePieceProcessor.DecodeIdsAsSerializedProto = SentencePieceProcessor.DecodeIdsAsSerializedProtoWithCheck
for m in [
'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused',