Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTaku Kudo <taku@google.com>2020-05-18 05:18:57 +0300
committerTaku Kudo <taku@google.com>2020-05-18 05:18:57 +0300
commitfbaf2f9b9b59cdd850b797300ef6b9b9ac3fe0af (patch)
tree946f3bc494e7bab85e979b15b24eb07da9864a6f
parent43a504eeead2b3226c3675c10b46edfba804f85f (diff)
supported pickle serialization
-rw-r--r--python/.gitignore1
-rw-r--r--python/sentencepiece.i501
-rw-r--r--python/sentencepiece.py522
-rw-r--r--python/sentencepiece_wrap.cxx188
-rwxr-xr-xpython/test/sentencepiece_test.py14
5 files changed, 564 insertions, 662 deletions
diff --git a/python/.gitignore b/python/.gitignore
index e4b2fd9..78e215b 100644
--- a/python/.gitignore
+++ b/python/.gitignore
@@ -1,2 +1,3 @@
/*.so
/build
+/*.pickle
diff --git a/python/sentencepiece.i b/python/sentencepiece.i
index 28e178a..e909506 100644
--- a/python/sentencepiece.i
+++ b/python/sentencepiece.i
@@ -171,15 +171,12 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
%ignore sentencepiece::SentencePieceProcessor::status;
%ignore sentencepiece::SentencePieceProcessor::Encode;
-%ignore sentencepiece::SentencePieceProcessor::Encode;
%ignore sentencepiece::SentencePieceProcessor::SampleEncode;
%ignore sentencepiece::SentencePieceProcessor::NBestEncode;
%ignore sentencepiece::SentencePieceProcessor::Decode;
%ignore sentencepiece::SentencePieceProcessor::model_proto;
-%ignore sentencepiece::SentencePieceProcessor::Load(std::istream *);
-%ignore sentencepiece::SentencePieceProcessor::LoadOrDie(std::istream *);
-%ignore sentencepiece::SentencePieceProcessor::Load(const ModelProto &);
-%ignore sentencepiece::SentencePieceProcessor::Load(std::unique_ptr<ModelProto>);
+%ignore sentencepiece::SentencePieceProcessor::Load;
+%ignore sentencepiece::SentencePieceProcessor::LoadOrDie;
%ignore sentencepiece::pretokenizer::PretokenizerForTrainingInterface;
%ignore sentencepiece::SentenceIterator;
%ignore sentencepiece::SentencePieceTrainer::Train;
@@ -193,50 +190,284 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
%ignore sentencepiece::SentencePieceTrainer::GetPretokenizerForTraining;
%extend sentencepiece::SentencePieceProcessor {
-
- int __len__() {
- return $self->GetPieceSize();
+ sentencepiece::util::Status LoadFromFile(absl::string_view arg) {
+ return $self->Load(arg);
}
- int __getitem__(absl::string_view key) const {
- return $self->PieceToId(key);
- }
+%pythoncode {
+ def Init(self,
+ model_file=None,
+ model_proto=None,
+ out_type=int,
+ add_bos=False,
+ add_eos=False,
+ reverse=False,
+ enable_sampling=False,
+ nbest_size=-1,
+ alpha=0.1):
+    """Initialize SentencePieceProcessor.
+
+ Args:
+ model_file: The sentencepiece model file path.
+ model_proto: The sentencepiece model serialized proto.
+ out_type: output type. int or str.
+ add_bos: Add <s> to the result (Default = false)
+ add_eos: Add </s> to the result (Default = false) <s>/</s> is added after
+ reversing (if enabled).
+ reverse: Reverses the tokenized sequence (Default = false)
+ nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout.
+ nbest_size = {0,1}: No sampling is performed.
+ nbest_size > 1: samples from the nbest_size results.
+ nbest_size < 0: assuming that nbest_size is infinite and samples
+ from the all hypothesis (lattice) using
+ forward-filtering-and-backward-sampling algorithm.
+      alpha: Smoothing parameter for unigram sampling, and merge probability for
+ BPE-dropout.
+ """
+
+ _sentencepiece_processor_init_native(self)
+ self._out_type = out_type
+ self._add_bos = add_bos
+ self._add_eos = add_eos
+ self._reverse = reverse
+ self._enable_sampling = enable_sampling
+ self._nbest_size = nbest_size
+ self._alpha = alpha
+ if model_file or model_proto:
+ self.Load(model_file=model_file, model_proto=model_proto)
+
+
+ def Encode(self,
+ input,
+ out_type=None,
+ add_bos=None,
+ add_eos=None,
+ reverse=None,
+ enable_sampling=None,
+ nbest_size=None,
+ alpha=None):
+ """Encode text input to segmented ids or tokens.
+
+ Args:
+      input: input string. accepts list of string.
+ out_type: output type. int or str.
+ add_bos: Add <s> to the result (Default = false)
+ add_eos: Add </s> to the result (Default = false) <s>/</s> is added after
+ reversing (if enabled).
+ reverse: Reverses the tokenized sequence (Default = false)
+ nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout.
+ nbest_size = {0,1}: No sampling is performed.
+ nbest_size > 1: samples from the nbest_size results.
+ nbest_size < 0: assuming that nbest_size is infinite and samples
+ from the all hypothesis (lattice) using
+ forward-filtering-and-backward-sampling algorithm.
+      alpha: Smoothing parameter for unigram sampling, and merge probability for
+ BPE-dropout.
+ """
+
+ if out_type is None:
+ out_type = self._out_type
+ if add_bos is None:
+ add_bos = self._add_bos
+ if add_eos is None:
+ add_eos = self._add_eos
+ if reverse is None:
+ reverse = self._reverse
+ if enable_sampling is None:
+ enable_sampling = self._enable_sampling
+ if nbest_size is None:
+ nbest_size = self._nbest_size
+ if alpha is None:
+ alpha = self._alpha
+
+ if enable_sampling == True and (nbest_size is None or nbest_size == 0 or
+ nbest_size == 1 or alpha is None or
+ alpha <= 0.0 or alpha > 1.0):
+ raise RuntimeError(
+ 'When enable_sampling is True, We must specify "nbest_size > 1" or "nbest_size = -1", '
+ 'and "0.0 < alpha < 1.0". "nbest_size = -1" is enabled only on unigram mode and '
+ 'samples from all candidates on the lattice instead of nbest segmentations. '
+ )
+
+ def _encode(text):
+ if out_type is int:
+ if enable_sampling:
+ result = self.SampleEncodeAsIds(text, nbest_size, alpha)
+ else:
+ result = self.EncodeAsIds(text)
+ else:
+ if enable_sampling:
+ result = self.SampleEncodeAsPieces(text, nbest_size, alpha)
+ else:
+ result = self.EncodeAsPieces(text)
+
+ if reverse:
+ result.reverse()
+ if add_bos:
+ if out_type is int:
+ result = [self.bos_id()] + result
+ else:
+ result = [self.IdToPiece(self.bos_id())] + result
+
+ if add_eos:
+ if out_type is int:
+ result = result + [self.eos_id()]
+ else:
+ result = result + [self.IdToPiece(self.eos_id())]
+
+ return result
+
+ if type(input) is list:
+ return [_encode(n) for n in input]
+
+ return _encode(input)
+
+
+ def Decode(self, input):
+ """Decode processed id or token sequences."""
+
+ if not input:
+ return self.DecodeIds([])
+ elif type(input) is int:
+ return self.DecodeIds([input])
+ elif type(input) is str:
+ return self.DecodePieces([input])
+
+ def _decode(input):
+ if not input:
+ return self.DecodeIds([])
+ if type(input[0]) is int:
+ return self.DecodeIds(input)
+ return self.DecodePieces(input)
+
+ if type(input[0]) is list:
+ return [_decode(n) for n in input]
+
+ return _decode(input)
+
+
+ def piece_size(self):
+ return self.GetPieceSize()
+
+
+ def vocab_size(self):
+ return self.GetPieceSize()
+
+
+ def __getstate__(self):
+ return self.serialized_model_proto()
+
+
+ def __setstate__(self, serialized_model_proto):
+ self.__init__()
+ self.LoadFromSerializedProto(serialized_model_proto)
+
+
+ def __len__(self):
+ return self.GetPieceSize()
+
+
+ def __getitem__(self, piece):
+ return self.PieceToId(piece)
+
+
+ def Load(self, model_file=None, model_proto=None):
+    """Override SentencePieceProcessor.Load to support both model_file and model_proto.
+
+ Args:
+ model_file: The sentencepiece model file path.
+ model_proto: The sentencepiece model serialized proto. Either `model_file`
+ or `model_proto` must be set.
+ """
+ if model_file and model_proto:
+ raise RuntimeError('model_file and model_proto must be exclusive.')
+ if model_proto:
+ return self.LoadFromSerializedProto(model_proto)
+ return self.LoadFromFile(model_file)
+}
}
%extend sentencepiece::SentencePieceTrainer {
- static void TrainFromString(absl::string_view arg) {
+ static void _TrainFromString(absl::string_view arg) {
const auto _status = sentencepiece::SentencePieceTrainer::Train(arg);
if (!_status.ok()) throw _status;
return;
}
- static void TrainFromMap(const std::map<std::string, std::string> &args) {
+ static void _TrainFromMap(const std::map<std::string, std::string> &args) {
const auto _status = sentencepiece::SentencePieceTrainer::Train(args);
if (!_status.ok()) throw _status;
return;
}
- static void TrainFromMap2(const std::map<std::string, std::string> &args,
+ static void _TrainFromMap2(const std::map<std::string, std::string> &args,
SentenceIterator *iter) {
const auto _status = sentencepiece::SentencePieceTrainer::Train(args, iter);
if (!_status.ok()) throw _status;
return;
}
- static sentencepiece::util::bytes TrainFromMap3(const std::map<std::string, std::string> &args) {
+ static sentencepiece::util::bytes _TrainFromMap3(const std::map<std::string, std::string> &args) {
sentencepiece::util::bytes model_proto;
const auto _status = sentencepiece::SentencePieceTrainer::Train(args, nullptr, &model_proto);
if (!_status.ok()) throw _status;
return model_proto;
}
- static sentencepiece::util::bytes TrainFromMap4(const std::map<std::string, std::string> &args,
+ static sentencepiece::util::bytes _TrainFromMap4(const std::map<std::string, std::string> &args,
SentenceIterator *iter) {
sentencepiece::util::bytes model_proto;
const auto _status = sentencepiece::SentencePieceTrainer::Train(args, iter, &model_proto);
if (!_status.ok()) throw _status;
return model_proto;
}
+
+%pythoncode {
+ @staticmethod
+ def Train(arg=None, **kwargs):
+ """Train Sentencepiece model. Accept both kwargs and legacy string arg."""
+ if arg is not None and type(arg) is str:
+ return SentencePieceTrainer._TrainFromString(arg)
+
+ def _encode(value):
+      """Encode value to CSV."""
+ if type(value) is list:
+ if sys.version_info[0] == 3:
+ f = StringIO()
+ else:
+ f = BytesIO()
+ writer = csv.writer(f, lineterminator='')
+ writer.writerow([str(v) for v in value])
+ return f.getvalue()
+ else:
+ return str(value)
+
+ sentence_iterator = None
+ model_writer = None
+ new_kwargs = {}
+ for key, value in kwargs.items():
+ if key in ['sentence_iterator', 'sentence_reader']:
+ sentence_iterator = value
+ elif key in ['model_writer']:
+ model_writer = value
+ else:
+ new_kwargs[key] = _encode(value)
+
+ if model_writer:
+ if sentence_iterator:
+ model_proto = SentencePieceTrainer._TrainFromMap4(new_kwargs,
+ sentence_iterator)
+ else:
+ model_proto = SentencePieceTrainer._TrainFromMap3(new_kwargs)
+ model_writer.write(model_proto)
+ else:
+ if sentence_iterator:
+ return SentencePieceTrainer._TrainFromMap2(new_kwargs, sentence_iterator)
+ else:
+ return SentencePieceTrainer._TrainFromMap(new_kwargs)
+
+ return None
+}
}
%typemap(out) std::vector<int> {
@@ -439,227 +670,6 @@ from io import StringIO
from io import BytesIO
-def _sentencepiece_processor_init(self,
- model_file=None,
- model_proto=None,
- out_type=int,
- add_bos=False,
- add_eos=False,
- reverse=False,
- enable_sampling=False,
- nbest_size=-1,
- alpha=0.1):
- """Overwride SentencePieceProcessor.__init__ to add addtional parameters.
-
- Args:
- model_file: The sentencepiece model file path.
- model_proto: The sentencepiece model serialized proto.
- out_type: output type. int or str.
- add_bos: Add <s> to the result (Default = false)
- add_eos: Add </s> to the result (Default = false) <s>/</s> is added after
- reversing (if enabled).
- reverse: Reverses the tokenized sequence (Default = false)
- nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout.
- nbest_size = {0,1}: No sampling is performed.
- nbest_size > 1: samples from the nbest_size results.
- nbest_size < 0: assuming that nbest_size is infinite and samples
- from the all hypothesis (lattice) using
- forward-filtering-and-backward-sampling algorithm.
- alpha: Soothing parameter for unigram sampling, and merge probability for
- BPE-dropout.
- """
-
- _sentencepiece_processor_init_native(self)
- self._out_type = out_type
- self._add_bos = add_bos
- self._add_eos = add_eos
- self._reverse = reverse
- self._enable_sampling = enable_sampling
- self._nbest_size = nbest_size
- self._alpha = alpha
- if model_file or model_proto:
- self.Load(model_file=model_file, model_proto=model_proto)
-
-
-def _sentencepiece_processor_load(self, model_file=None, model_proto=None):
- """Overwride SentencePieceProcessor.Load to support both model_file and model_proto.
-
- Args:
- model_file: The sentencepiece model file path.
- model_proto: The sentencepiece model serialized proto. Either `model_file`
- or `model_proto` must be set.
- """
- if model_file and model_proto:
- raise RuntimeError('model_file and model_proto must be exclusive.')
- if model_proto:
- return self._LoadFromSerializedProto_native(model_proto)
- return self._Load_native(model_file)
-
-
-def _sentencepiece_processor_encode(self,
- input,
- out_type=None,
- add_bos=None,
- add_eos=None,
- reverse=None,
- enable_sampling=None,
- nbest_size=None,
- alpha=None):
- """Encode text input to segmented ids or tokens.
-
- Args:
- input: input string. accepsts list of string.
- out_type: output type. int or str.
- add_bos: Add <s> to the result (Default = false)
- add_eos: Add </s> to the result (Default = false) <s>/</s> is added after
- reversing (if enabled).
- reverse: Reverses the tokenized sequence (Default = false)
- nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout.
- nbest_size = {0,1}: No sampling is performed.
- nbest_size > 1: samples from the nbest_size results.
- nbest_size < 0: assuming that nbest_size is infinite and samples
- from the all hypothesis (lattice) using
- forward-filtering-and-backward-sampling algorithm.
- alpha: Soothing parameter for unigram sampling, and merge probability for
- BPE-dropout.
- """
-
- if out_type is None:
- out_type = self._out_type
- if add_bos is None:
- add_bos = self._add_bos
- if add_eos is None:
- add_eos = self._add_eos
- if reverse is None:
- reverse = self._reverse
- if enable_sampling is None:
- enable_sampling = self._enable_sampling
- if nbest_size is None:
- nbest_size = self._nbest_size
- if alpha is None:
- alpha = self._alpha
-
- if enable_sampling == True and (nbest_size is None or nbest_size == 0 or
- nbest_size == 1 or alpha is None or
- alpha <= 0.0 or alpha > 1.0):
- raise RuntimeError(
- 'When enable_sampling is True, We must specify "nbest_size > 1" or "nbest_size = -1", '
- 'and "0.0 < alpha < 1.0". "nbest_size = -1" is enabled only on unigram mode and '
- 'samples from all candidates on the lattice instead of nbest segmentations. '
- )
-
- def _encode(text):
- if out_type is int:
- if enable_sampling:
- result = self.SampleEncodeAsIds(text, nbest_size, alpha)
- else:
- result = self.EncodeAsIds(text)
- else:
- if enable_sampling:
- result = self.SampleEncodeAsPieces(text, nbest_size, alpha)
- else:
- result = self.EncodeAsPieces(text)
-
- if reverse:
- result.reverse()
- if add_bos:
- if out_type is int:
- result = [self.bos_id()] + result
- else:
- result = [self.IdToPiece(self.bos_id())] + result
-
- if add_eos:
- if out_type is int:
- result = result + [self.eos_id()]
- else:
- result = result + [self.IdToPiece(self.eos_id())]
-
- return result
-
- if type(input) is list:
- return [_encode(n) for n in input]
-
- return _encode(input)
-
-
-def _sentencepiece_processor_decode(self, input):
- """Decode processed id or token sequences."""
-
- if not input:
- return self.DecodeIds([])
- elif type(input) is int:
- return self.DecodeIds([input])
- elif type(input) is str:
- return self.DecodePieces([input])
-
- def _decode(input):
- if not input:
- return self.DecodeIds([])
- if type(input[0]) is int:
- return self.DecodeIds(input)
- return self.DecodePieces(input)
-
- if type(input[0]) is list:
- return [_decode(n) for n in input]
-
- return _decode(input)
-
-def _sentencepiece_trainer_train(arg=None, **kwargs):
- """Train Sentencepiece model. Accept both kwargs and legacy string arg."""
- if arg is not None and type(arg) is str:
- return SentencePieceTrainer.TrainFromString(arg)
-
- def _encode(value):
- """Encode value to CSV.."""
- if type(value) is list:
- if sys.version_info[0] == 3:
- f = StringIO()
- else:
- f = BytesIO()
- writer = csv.writer(f, lineterminator='')
- writer.writerow([str(v) for v in value])
- return f.getvalue()
- else:
- return str(value)
-
- sentence_iterator = None
- model_writer = None
- new_kwargs = {}
- for key, value in kwargs.items():
- if key in ['sentence_iterator', 'sentence_reader']:
- sentence_iterator = value
- elif key in ['model_writer']:
- model_writer = value
- else:
- new_kwargs[key] = _encode(value)
-
- if model_writer:
- if sentence_iterator:
- model_proto = SentencePieceTrainer.TrainFromMap4(new_kwargs,
- sentence_iterator)
- else:
- model_proto = SentencePieceTrainer.TrainFromMap3(new_kwargs)
- model_writer.write(model_proto)
- else:
- if sentence_iterator:
- return SentencePieceTrainer.TrainFromMap2(new_kwargs, sentence_iterator)
- else:
- return SentencePieceTrainer.TrainFromMap(new_kwargs)
-
- return None
-
-
-def _save_native(classname):
- """Stores the origina method as _{method_name}_native."""
-
- native_map = {}
- for name, method in classname.__dict__.items():
- if name[0] != '_':
- native_map[('_%s_native' % name)] = method
- for k, v in native_map.items():
- setattr(classname, k, v)
-
-
def _add_snake_case(classname):
  """Adds snake_cased methods from CamelCased methods."""
@@ -675,8 +685,7 @@ def _add_snake_case(classname):
def _batchnize(classname, name):
"""Enables batch request for the method classname.name."""
-
- func = getattr(classname, '_%s_native' % name, None)
+ func = getattr(classname, name, None)
def _batched_func(self, arg):
if type(arg) is list:
@@ -687,17 +696,11 @@ def _batchnize(classname, name):
setattr(classname, name, _batched_func)
-_save_native(SentencePieceProcessor)
_sentencepiece_processor_init_native = SentencePieceProcessor.__init__
-setattr(SentencePieceProcessor, 'Encode', _sentencepiece_processor_encode)
-setattr(SentencePieceProcessor, 'Tokenize', _sentencepiece_processor_encode)
-setattr(SentencePieceProcessor, 'Decode', _sentencepiece_processor_decode)
-setattr(SentencePieceProcessor, 'Detokenize', _sentencepiece_processor_decode)
-setattr(SentencePieceProcessor, 'Load', _sentencepiece_processor_load)
-setattr(SentencePieceProcessor, '__init__', _sentencepiece_processor_init)
-setattr(SentencePieceProcessor, 'vocab_size', SentencePieceProcessor.GetPieceSize)
-setattr(SentencePieceProcessor, 'piece_size', SentencePieceProcessor.GetPieceSize)
-setattr(SentencePieceTrainer, 'Train', staticmethod(_sentencepiece_trainer_train))
+setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)
+
+SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
+SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
for m in [
'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused',
diff --git a/python/sentencepiece.py b/python/sentencepiece.py
index 7fd116e..8f1b038 100644
--- a/python/sentencepiece.py
+++ b/python/sentencepiece.py
@@ -71,12 +71,6 @@ class SentencePieceProcessor(object):
_sentencepiece.SentencePieceProcessor_swiginit(self, _sentencepiece.new_SentencePieceProcessor())
__swig_destroy__ = _sentencepiece.delete_SentencePieceProcessor
- def Load(self, filename):
- return _sentencepiece.SentencePieceProcessor_Load(self, filename)
-
- def LoadOrDie(self, filename):
- return _sentencepiece.SentencePieceProcessor_LoadOrDie(self, filename)
-
def LoadFromSerializedProto(self, serialized):
return _sentencepiece.SentencePieceProcessor_LoadFromSerializedProto(self, serialized)
@@ -179,288 +173,301 @@ class SentencePieceProcessor(object):
def serialized_model_proto(self):
return _sentencepiece.SentencePieceProcessor_serialized_model_proto(self)
- def __len__(self):
- return _sentencepiece.SentencePieceProcessor___len__(self)
+ def LoadFromFile(self, arg):
+ return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg)
+
+ def Init(self,
+ model_file=None,
+ model_proto=None,
+ out_type=int,
+ add_bos=False,
+ add_eos=False,
+ reverse=False,
+ enable_sampling=False,
+ nbest_size=-1,
+ alpha=0.1):
+      """Initialize SentencePieceProcessor.
+
+ Args:
+ model_file: The sentencepiece model file path.
+ model_proto: The sentencepiece model serialized proto.
+ out_type: output type. int or str.
+ add_bos: Add <s> to the result (Default = false)
+ add_eos: Add </s> to the result (Default = false) <s>/</s> is added after
+ reversing (if enabled).
+ reverse: Reverses the tokenized sequence (Default = false)
+ nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout.
+ nbest_size = {0,1}: No sampling is performed.
+ nbest_size > 1: samples from the nbest_size results.
+ nbest_size < 0: assuming that nbest_size is infinite and samples
+ from the all hypothesis (lattice) using
+ forward-filtering-and-backward-sampling algorithm.
+        alpha: Smoothing parameter for unigram sampling, and merge probability for
+ BPE-dropout.
+ """
+
+ _sentencepiece_processor_init_native(self)
+ self._out_type = out_type
+ self._add_bos = add_bos
+ self._add_eos = add_eos
+ self._reverse = reverse
+ self._enable_sampling = enable_sampling
+ self._nbest_size = nbest_size
+ self._alpha = alpha
+ if model_file or model_proto:
+ self.Load(model_file=model_file, model_proto=model_proto)
+
+
+ def Encode(self,
+ input,
+ out_type=None,
+ add_bos=None,
+ add_eos=None,
+ reverse=None,
+ enable_sampling=None,
+ nbest_size=None,
+ alpha=None):
+ """Encode text input to segmented ids or tokens.
+
+ Args:
+        input: input string. accepts list of string.
+ out_type: output type. int or str.
+ add_bos: Add <s> to the result (Default = false)
+ add_eos: Add </s> to the result (Default = false) <s>/</s> is added after
+ reversing (if enabled).
+ reverse: Reverses the tokenized sequence (Default = false)
+ nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout.
+ nbest_size = {0,1}: No sampling is performed.
+ nbest_size > 1: samples from the nbest_size results.
+ nbest_size < 0: assuming that nbest_size is infinite and samples
+ from the all hypothesis (lattice) using
+ forward-filtering-and-backward-sampling algorithm.
+        alpha: Smoothing parameter for unigram sampling, and merge probability for
+ BPE-dropout.
+ """
+
+ if out_type is None:
+ out_type = self._out_type
+ if add_bos is None:
+ add_bos = self._add_bos
+ if add_eos is None:
+ add_eos = self._add_eos
+ if reverse is None:
+ reverse = self._reverse
+ if enable_sampling is None:
+ enable_sampling = self._enable_sampling
+ if nbest_size is None:
+ nbest_size = self._nbest_size
+ if alpha is None:
+ alpha = self._alpha
+
+ if enable_sampling == True and (nbest_size is None or nbest_size == 0 or
+ nbest_size == 1 or alpha is None or
+ alpha <= 0.0 or alpha > 1.0):
+ raise RuntimeError(
+ 'When enable_sampling is True, We must specify "nbest_size > 1" or "nbest_size = -1", '
+ 'and "0.0 < alpha < 1.0". "nbest_size = -1" is enabled only on unigram mode and '
+ 'samples from all candidates on the lattice instead of nbest segmentations. '
+ )
+
+ def _encode(text):
+ if out_type is int:
+ if enable_sampling:
+ result = self.SampleEncodeAsIds(text, nbest_size, alpha)
+ else:
+ result = self.EncodeAsIds(text)
+ else:
+ if enable_sampling:
+ result = self.SampleEncodeAsPieces(text, nbest_size, alpha)
+ else:
+ result = self.EncodeAsPieces(text)
- def __getitem__(self, key):
- return _sentencepiece.SentencePieceProcessor___getitem__(self, key)
+ if reverse:
+ result.reverse()
+ if add_bos:
+ if out_type is int:
+ result = [self.bos_id()] + result
+ else:
+ result = [self.IdToPiece(self.bos_id())] + result
-# Register SentencePieceProcessor in _sentencepiece:
-_sentencepiece.SentencePieceProcessor_swigregister(SentencePieceProcessor)
+ if add_eos:
+ if out_type is int:
+ result = result + [self.eos_id()]
+ else:
+ result = result + [self.IdToPiece(self.eos_id())]
-class SentencePieceTrainer(object):
- thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
+ return result
- def __init__(self, *args, **kwargs):
- raise AttributeError("No constructor defined")
- __repr__ = _swig_repr
+ if type(input) is list:
+ return [_encode(n) for n in input]
- @staticmethod
- def TrainFromString(arg):
- return _sentencepiece.SentencePieceTrainer_TrainFromString(arg)
+ return _encode(input)
- @staticmethod
- def TrainFromMap(args):
- return _sentencepiece.SentencePieceTrainer_TrainFromMap(args)
- @staticmethod
- def TrainFromMap2(args, iter):
- return _sentencepiece.SentencePieceTrainer_TrainFromMap2(args, iter)
+ def Decode(self, input):
+ """Decode processed id or token sequences."""
- @staticmethod
- def TrainFromMap3(args):
- return _sentencepiece.SentencePieceTrainer_TrainFromMap3(args)
+ if not input:
+ return self.DecodeIds([])
+ elif type(input) is int:
+ return self.DecodeIds([input])
+ elif type(input) is str:
+ return self.DecodePieces([input])
- @staticmethod
- def TrainFromMap4(args, iter):
- return _sentencepiece.SentencePieceTrainer_TrainFromMap4(args, iter)
+ def _decode(input):
+ if not input:
+ return self.DecodeIds([])
+ if type(input[0]) is int:
+ return self.DecodeIds(input)
+ return self.DecodePieces(input)
-# Register SentencePieceTrainer in _sentencepiece:
-_sentencepiece.SentencePieceTrainer_swigregister(SentencePieceTrainer)
+ if type(input[0]) is list:
+ return [_decode(n) for n in input]
-def SentencePieceTrainer_TrainFromString(arg):
- return _sentencepiece.SentencePieceTrainer_TrainFromString(arg)
+ return _decode(input)
-def SentencePieceTrainer_TrainFromMap(args):
- return _sentencepiece.SentencePieceTrainer_TrainFromMap(args)
-def SentencePieceTrainer_TrainFromMap2(args, iter):
- return _sentencepiece.SentencePieceTrainer_TrainFromMap2(args, iter)
+ def piece_size(self):
+ return self.GetPieceSize()
-def SentencePieceTrainer_TrainFromMap3(args):
- return _sentencepiece.SentencePieceTrainer_TrainFromMap3(args)
-def SentencePieceTrainer_TrainFromMap4(args, iter):
- return _sentencepiece.SentencePieceTrainer_TrainFromMap4(args, iter)
+ def vocab_size(self):
+ return self.GetPieceSize()
+ def __getstate__(self):
+ return self.serialized_model_proto()
-import re
-import csv
-import sys
-from io import StringIO
-from io import BytesIO
+ def __setstate__(self, serialized_model_proto):
+ self.__init__()
+ self.LoadFromSerializedProto(serialized_model_proto)
-def _sentencepiece_processor_init(self,
- model_file=None,
- model_proto=None,
- out_type=int,
- add_bos=False,
- add_eos=False,
- reverse=False,
- enable_sampling=False,
- nbest_size=-1,
- alpha=0.1):
- """Overwride SentencePieceProcessor.__init__ to add addtional parameters.
-
- Args:
- model_file: The sentencepiece model file path.
- model_proto: The sentencepiece model serialized proto.
- out_type: output type. int or str.
- add_bos: Add <s> to the result (Default = false)
- add_eos: Add </s> to the result (Default = false) <s>/</s> is added after
- reversing (if enabled).
- reverse: Reverses the tokenized sequence (Default = false)
- nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout.
- nbest_size = {0,1}: No sampling is performed.
- nbest_size > 1: samples from the nbest_size results.
- nbest_size < 0: assuming that nbest_size is infinite and samples
- from the all hypothesis (lattice) using
- forward-filtering-and-backward-sampling algorithm.
- alpha: Soothing parameter for unigram sampling, and merge probability for
- BPE-dropout.
- """
-
- _sentencepiece_processor_init_native(self)
- self._out_type = out_type
- self._add_bos = add_bos
- self._add_eos = add_eos
- self._reverse = reverse
- self._enable_sampling = enable_sampling
- self._nbest_size = nbest_size
- self._alpha = alpha
- if model_file or model_proto:
- self.Load(model_file=model_file, model_proto=model_proto)
-
-
-def _sentencepiece_processor_load(self, model_file=None, model_proto=None):
- """Overwride SentencePieceProcessor.Load to support both model_file and model_proto.
-
- Args:
- model_file: The sentencepiece model file path.
- model_proto: The sentencepiece model serialized proto. Either `model_file`
- or `model_proto` must be set.
- """
- if model_file and model_proto:
- raise RuntimeError('model_file and model_proto must be exclusive.')
- if model_proto:
- return self._LoadFromSerializedProto_native(model_proto)
- return self._Load_native(model_file)
-
-
-def _sentencepiece_processor_encode(self,
- input,
- out_type=None,
- add_bos=None,
- add_eos=None,
- reverse=None,
- enable_sampling=None,
- nbest_size=None,
- alpha=None):
- """Encode text input to segmented ids or tokens.
-
- Args:
- input: input string. accepsts list of string.
- out_type: output type. int or str.
- add_bos: Add <s> to the result (Default = false)
- add_eos: Add </s> to the result (Default = false) <s>/</s> is added after
- reversing (if enabled).
- reverse: Reverses the tokenized sequence (Default = false)
- nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout.
- nbest_size = {0,1}: No sampling is performed.
- nbest_size > 1: samples from the nbest_size results.
- nbest_size < 0: assuming that nbest_size is infinite and samples
- from the all hypothesis (lattice) using
- forward-filtering-and-backward-sampling algorithm.
- alpha: Soothing parameter for unigram sampling, and merge probability for
- BPE-dropout.
- """
-
- if out_type is None:
- out_type = self._out_type
- if add_bos is None:
- add_bos = self._add_bos
- if add_eos is None:
- add_eos = self._add_eos
- if reverse is None:
- reverse = self._reverse
- if enable_sampling is None:
- enable_sampling = self._enable_sampling
- if nbest_size is None:
- nbest_size = self._nbest_size
- if alpha is None:
- alpha = self._alpha
-
- if enable_sampling == True and (nbest_size is None or nbest_size == 0 or
- nbest_size == 1 or alpha is None or
- alpha <= 0.0 or alpha > 1.0):
- raise RuntimeError(
- 'When enable_sampling is True, We must specify "nbest_size > 1" or "nbest_size = -1", '
- 'and "0.0 < alpha < 1.0". "nbest_size = -1" is enabled only on unigram mode and '
- 'samples from all candidates on the lattice instead of nbest segmentations. '
- )
-
- def _encode(text):
- if out_type is int:
- if enable_sampling:
- result = self.SampleEncodeAsIds(text, nbest_size, alpha)
- else:
- result = self.EncodeAsIds(text)
- else:
- if enable_sampling:
- result = self.SampleEncodeAsPieces(text, nbest_size, alpha)
- else:
- result = self.EncodeAsPieces(text)
- if reverse:
- result.reverse()
- if add_bos:
- if out_type is int:
- result = [self.bos_id()] + result
- else:
- result = [self.IdToPiece(self.bos_id())] + result
+ def __len__(self):
+ return self.GetPieceSize()
+
+
+ def __getitem__(self, piece):
+ return self.PieceToId(piece)
- if add_eos:
- if out_type is int:
- result = result + [self.eos_id()]
- else:
- result = result + [self.IdToPiece(self.eos_id())]
- return result
+ def Load(self, model_file=None, model_proto=None):
+      """Override SentencePieceProcessor.Load to support both model_file and model_proto.
- if type(input) is list:
- return [_encode(n) for n in input]
+ Args:
+ model_file: The sentencepiece model file path.
+ model_proto: The sentencepiece model serialized proto. Either `model_file`
+ or `model_proto` must be set.
+ """
+ if model_file and model_proto:
+ raise RuntimeError('model_file and model_proto must be exclusive.')
+ if model_proto:
+ return self.LoadFromSerializedProto(model_proto)
+ return self.LoadFromFile(model_file)
+
+
+# Register SentencePieceProcessor in _sentencepiece:
+_sentencepiece.SentencePieceProcessor_swigregister(SentencePieceProcessor)
- return _encode(input)
+class SentencePieceTrainer(object):
+ thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
+ def __init__(self, *args, **kwargs):
+ raise AttributeError("No constructor defined")
+ __repr__ = _swig_repr
-def _sentencepiece_processor_decode(self, input):
- """Decode processed id or token sequences."""
+ @staticmethod
+ def _TrainFromString(arg):
+ return _sentencepiece.SentencePieceTrainer__TrainFromString(arg)
- if not input:
- return self.DecodeIds([])
- elif type(input) is int:
- return self.DecodeIds([input])
- elif type(input) is str:
- return self.DecodePieces([input])
+ @staticmethod
+ def _TrainFromMap(args):
+ return _sentencepiece.SentencePieceTrainer__TrainFromMap(args)
- def _decode(input):
- if not input:
- return self.DecodeIds([])
- if type(input[0]) is int:
- return self.DecodeIds(input)
- return self.DecodePieces(input)
+ @staticmethod
+ def _TrainFromMap2(args, iter):
+ return _sentencepiece.SentencePieceTrainer__TrainFromMap2(args, iter)
- if type(input[0]) is list:
- return [_decode(n) for n in input]
+ @staticmethod
+ def _TrainFromMap3(args):
+ return _sentencepiece.SentencePieceTrainer__TrainFromMap3(args)
- return _decode(input)
+ @staticmethod
+ def _TrainFromMap4(args, iter):
+ return _sentencepiece.SentencePieceTrainer__TrainFromMap4(args, iter)
-def _sentencepiece_trainer_train(arg=None, **kwargs):
- """Train Sentencepiece model. Accept both kwargs and legacy string arg."""
- if arg is not None and type(arg) is str:
- return SentencePieceTrainer.TrainFromString(arg)
+ @staticmethod
+ def Train(arg=None, **kwargs):
+ """Train Sentencepiece model. Accept both kwargs and legacy string arg."""
+ if arg is not None and type(arg) is str:
+ return SentencePieceTrainer._TrainFromString(arg)
+
+ def _encode(value):
+ """Encode value to CSV.."""
+ if type(value) is list:
+ if sys.version_info[0] == 3:
+ f = StringIO()
+ else:
+ f = BytesIO()
+ writer = csv.writer(f, lineterminator='')
+ writer.writerow([str(v) for v in value])
+ return f.getvalue()
+ else:
+ return str(value)
+
+ sentence_iterator = None
+ model_writer = None
+ new_kwargs = {}
+ for key, value in kwargs.items():
+ if key in ['sentence_iterator', 'sentence_reader']:
+ sentence_iterator = value
+ elif key in ['model_writer']:
+ model_writer = value
+ else:
+ new_kwargs[key] = _encode(value)
- def _encode(value):
- """Encode value to CSV.."""
- if type(value) is list:
- if sys.version_info[0] == 3:
- f = StringIO()
+ if model_writer:
+ if sentence_iterator:
+ model_proto = SentencePieceTrainer._TrainFromMap4(new_kwargs,
+ sentence_iterator)
+ else:
+ model_proto = SentencePieceTrainer._TrainFromMap3(new_kwargs)
+ model_writer.write(model_proto)
else:
- f = BytesIO()
- writer = csv.writer(f, lineterminator='')
- writer.writerow([str(v) for v in value])
- return f.getvalue()
- else:
- return str(value)
-
- sentence_iterator = None
- model_writer = None
- new_kwargs = {}
- for key, value in kwargs.items():
- if key in ['sentence_iterator', 'sentence_reader']:
- sentence_iterator = value
- elif key in ['model_writer']:
- model_writer = value
- else:
- new_kwargs[key] = _encode(value)
+ if sentence_iterator:
+ return SentencePieceTrainer._TrainFromMap2(new_kwargs, sentence_iterator)
+ else:
+ return SentencePieceTrainer._TrainFromMap(new_kwargs)
- if model_writer:
- if sentence_iterator:
- model_proto = SentencePieceTrainer.TrainFromMap4(new_kwargs,
- sentence_iterator)
- else:
- model_proto = SentencePieceTrainer.TrainFromMap3(new_kwargs)
- model_writer.write(model_proto)
- else:
- if sentence_iterator:
- return SentencePieceTrainer.TrainFromMap2(new_kwargs, sentence_iterator)
- else:
- return SentencePieceTrainer.TrainFromMap(new_kwargs)
+ return None
- return None
+# Register SentencePieceTrainer in _sentencepiece:
+_sentencepiece.SentencePieceTrainer_swigregister(SentencePieceTrainer)
-def _save_native(classname):
- """Stores the origina method as _{method_name}_native."""
+def SentencePieceTrainer__TrainFromString(arg):
+ return _sentencepiece.SentencePieceTrainer__TrainFromString(arg)
- native_map = {}
- for name, method in classname.__dict__.items():
- if name[0] != '_':
- native_map[('_%s_native' % name)] = method
- for k, v in native_map.items():
- setattr(classname, k, v)
+def SentencePieceTrainer__TrainFromMap(args):
+ return _sentencepiece.SentencePieceTrainer__TrainFromMap(args)
+
+def SentencePieceTrainer__TrainFromMap2(args, iter):
+ return _sentencepiece.SentencePieceTrainer__TrainFromMap2(args, iter)
+
+def SentencePieceTrainer__TrainFromMap3(args):
+ return _sentencepiece.SentencePieceTrainer__TrainFromMap3(args)
+
+def SentencePieceTrainer__TrainFromMap4(args, iter):
+ return _sentencepiece.SentencePieceTrainer__TrainFromMap4(args, iter)
+
+
+
+import re
+import csv
+import sys
+from io import StringIO
+from io import BytesIO
def _add_snake_case(classname):
@@ -478,8 +485,7 @@ def _add_snake_case(classname):
def _batchnize(classname, name):
"""Enables batch request for the method classname.name."""
-
- func = getattr(classname, '_%s_native' % name, None)
+ func = getattr(classname, name, None)
def _batched_func(self, arg):
if type(arg) is list:
@@ -490,17 +496,11 @@ def _batchnize(classname, name):
setattr(classname, name, _batched_func)
-_save_native(SentencePieceProcessor)
_sentencepiece_processor_init_native = SentencePieceProcessor.__init__
-setattr(SentencePieceProcessor, 'Encode', _sentencepiece_processor_encode)
-setattr(SentencePieceProcessor, 'Tokenize', _sentencepiece_processor_encode)
-setattr(SentencePieceProcessor, 'Decode', _sentencepiece_processor_decode)
-setattr(SentencePieceProcessor, 'Detokenize', _sentencepiece_processor_decode)
-setattr(SentencePieceProcessor, 'Load', _sentencepiece_processor_load)
-setattr(SentencePieceProcessor, '__init__', _sentencepiece_processor_init)
-setattr(SentencePieceProcessor, 'vocab_size', SentencePieceProcessor.GetPieceSize)
-setattr(SentencePieceProcessor, 'piece_size', SentencePieceProcessor.GetPieceSize)
-setattr(SentencePieceTrainer, 'Train', staticmethod(_sentencepiece_trainer_train))
+setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)
+
+SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
+SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
for m in [
'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused',
diff --git a/python/sentencepiece_wrap.cxx b/python/sentencepiece_wrap.cxx
index 12162b6..e47b780 100644
--- a/python/sentencepiece_wrap.cxx
+++ b/python/sentencepiece_wrap.cxx
@@ -3280,34 +3280,31 @@ SWIGINTERNINLINE PyObject*
return PyBool_FromLong(value ? 1 : 0);
}
-SWIGINTERN int sentencepiece_SentencePieceProcessor___len__(sentencepiece::SentencePieceProcessor *self){
- return self->GetPieceSize();
+SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceProcessor_LoadFromFile(sentencepiece::SentencePieceProcessor *self,absl::string_view arg){
+ return self->Load(arg);
}
-SWIGINTERN int sentencepiece_SentencePieceProcessor___getitem__(sentencepiece::SentencePieceProcessor const *self,absl::string_view key){
- return self->PieceToId(key);
- }
-SWIGINTERN void sentencepiece_SentencePieceTrainer_TrainFromString(absl::string_view arg){
+SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromString(absl::string_view arg){
const auto _status = sentencepiece::SentencePieceTrainer::Train(arg);
if (!_status.ok()) throw _status;
return;
}
-SWIGINTERN void sentencepiece_SentencePieceTrainer_TrainFromMap(std::map< std::string,std::string > const &args){
+SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromMap(std::map< std::string,std::string > const &args){
const auto _status = sentencepiece::SentencePieceTrainer::Train(args);
if (!_status.ok()) throw _status;
return;
}
-SWIGINTERN void sentencepiece_SentencePieceTrainer_TrainFromMap2(std::map< std::string,std::string > const &args,sentencepiece::SentenceIterator *iter){
+SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromMap2(std::map< std::string,std::string > const &args,sentencepiece::SentenceIterator *iter){
const auto _status = sentencepiece::SentencePieceTrainer::Train(args, iter);
if (!_status.ok()) throw _status;
return;
}
-SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceTrainer_TrainFromMap3(std::map< std::string,std::string > const &args){
+SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceTrainer__TrainFromMap3(std::map< std::string,std::string > const &args){
sentencepiece::util::bytes model_proto;
const auto _status = sentencepiece::SentencePieceTrainer::Train(args, nullptr, &model_proto);
if (!_status.ok()) throw _status;
return model_proto;
}
-SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceTrainer_TrainFromMap4(std::map< std::string,std::string > const &args,sentencepiece::SentenceIterator *iter){
+SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceTrainer__TrainFromMap4(std::map< std::string,std::string > const &args,sentencepiece::SentenceIterator *iter){
sentencepiece::util::bytes model_proto;
const auto _status = sentencepiece::SentencePieceTrainer::Train(args, iter, &model_proto);
if (!_status.ok()) throw _status;
@@ -3367,90 +3364,6 @@ fail:
}
-SWIGINTERN PyObject *_wrap_SentencePieceProcessor_Load(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
- PyObject *resultobj = 0;
- sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
- absl::string_view arg2 ;
- void *argp1 = 0 ;
- int res1 = 0 ;
- PyObject *swig_obj[2] ;
- sentencepiece::util::Status result;
-
- if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_Load", 2, 2, swig_obj)) SWIG_fail;
- res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 );
- if (!SWIG_IsOK(res1)) {
- SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_Load" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'");
- }
- arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
- {
- const PyInputString ustring(swig_obj[1]);
- if (!ustring.IsAvalable()) {
- PyErr_SetString(PyExc_TypeError, "not a string");
- SWIG_fail;
- }
- resultobj = ustring.input_type();
- arg2 = absl::string_view(ustring.data(), ustring.size());
- }
- {
- try {
- result = (arg1)->Load(arg2);
- ReleaseResultObject(resultobj);
- }
- catch (const sentencepiece::util::Status &status) {
- SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
- }
- }
- {
- if (!(&result)->ok()) {
- SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
- }
- resultobj = SWIG_From_bool((&result)->ok());
- }
- return resultobj;
-fail:
- return NULL;
-}
-
-
-SWIGINTERN PyObject *_wrap_SentencePieceProcessor_LoadOrDie(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
- PyObject *resultobj = 0;
- sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
- absl::string_view arg2 ;
- void *argp1 = 0 ;
- int res1 = 0 ;
- PyObject *swig_obj[2] ;
-
- if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_LoadOrDie", 2, 2, swig_obj)) SWIG_fail;
- res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 );
- if (!SWIG_IsOK(res1)) {
- SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_LoadOrDie" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'");
- }
- arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
- {
- const PyInputString ustring(swig_obj[1]);
- if (!ustring.IsAvalable()) {
- PyErr_SetString(PyExc_TypeError, "not a string");
- SWIG_fail;
- }
- resultobj = ustring.input_type();
- arg2 = absl::string_view(ustring.data(), ustring.size());
- }
- {
- try {
- (arg1)->LoadOrDie(arg2);
- ReleaseResultObject(resultobj);
- }
- catch (const sentencepiece::util::Status &status) {
- SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
- }
- }
- resultobj = SWIG_Py_Void();
- return resultobj;
-fail:
- return NULL;
-}
-
-
SWIGINTERN PyObject *_wrap_SentencePieceProcessor_LoadFromSerializedProto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
@@ -4990,50 +4903,19 @@ fail:
}
-SWIGINTERN PyObject *_wrap_SentencePieceProcessor___len__(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
- PyObject *resultobj = 0;
- sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
- void *argp1 = 0 ;
- int res1 = 0 ;
- PyObject *swig_obj[1] ;
- int result;
-
- if (!args) SWIG_fail;
- swig_obj[0] = args;
- res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 );
- if (!SWIG_IsOK(res1)) {
- SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor___len__" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'");
- }
- arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
- {
- try {
- result = (int)sentencepiece_SentencePieceProcessor___len__(arg1);
- ReleaseResultObject(resultobj);
- }
- catch (const sentencepiece::util::Status &status) {
- SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
- }
- }
- resultobj = SWIG_From_int(static_cast< int >(result));
- return resultobj;
-fail:
- return NULL;
-}
-
-
-SWIGINTERN PyObject *_wrap_SentencePieceProcessor___getitem__(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
+SWIGINTERN PyObject *_wrap_SentencePieceProcessor_LoadFromFile(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
absl::string_view arg2 ;
void *argp1 = 0 ;
int res1 = 0 ;
PyObject *swig_obj[2] ;
- int result;
+ sentencepiece::util::Status result;
- if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor___getitem__", 2, 2, swig_obj)) SWIG_fail;
+ if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_LoadFromFile", 2, 2, swig_obj)) SWIG_fail;
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 );
if (!SWIG_IsOK(res1)) {
- SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor___getitem__" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'");
+ SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_LoadFromFile" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'");
}
arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
{
@@ -5047,14 +4929,19 @@ SWIGINTERN PyObject *_wrap_SentencePieceProcessor___getitem__(PyObject *SWIGUNUS
}
{
try {
- result = (int)sentencepiece_SentencePieceProcessor___getitem__((sentencepiece::SentencePieceProcessor const *)arg1,arg2);
+ result = sentencepiece_SentencePieceProcessor_LoadFromFile(arg1,arg2);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
- resultobj = SWIG_From_int(static_cast< int >(result));
+ {
+ if (!(&result)->ok()) {
+ SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
+ }
+ resultobj = SWIG_From_bool((&result)->ok());
+ }
return resultobj;
fail:
return NULL;
@@ -5072,7 +4959,7 @@ SWIGINTERN PyObject *SentencePieceProcessor_swiginit(PyObject *SWIGUNUSEDPARM(se
return SWIG_Python_InitShadowInstance(args);
}
-SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromString(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
+SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromString(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
absl::string_view arg1 ;
PyObject *swig_obj[1] ;
@@ -5090,7 +4977,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromString(PyObject *SWIGUN
}
{
try {
- sentencepiece_SentencePieceTrainer_TrainFromString(arg1);
+ sentencepiece_SentencePieceTrainer__TrainFromString(arg1);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
@@ -5104,7 +4991,7 @@ fail:
}
-SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
+SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
std::map< std::string,std::string > *arg1 = 0 ;
PyObject *swig_obj[1] ;
@@ -5137,7 +5024,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap(PyObject *SWIGUNUSE
}
{
try {
- sentencepiece_SentencePieceTrainer_TrainFromMap((std::map< std::string,std::string > const &)*arg1);
+ sentencepiece_SentencePieceTrainer__TrainFromMap((std::map< std::string,std::string > const &)*arg1);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
@@ -5157,13 +5044,13 @@ fail:
}
-SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap2(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
+SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap2(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
std::map< std::string,std::string > *arg1 = 0 ;
sentencepiece::SentenceIterator *arg2 = (sentencepiece::SentenceIterator *) 0 ;
PyObject *swig_obj[2] ;
- if (!SWIG_Python_UnpackTuple(args, "SentencePieceTrainer_TrainFromMap2", 2, 2, swig_obj)) SWIG_fail;
+ if (!SWIG_Python_UnpackTuple(args, "SentencePieceTrainer__TrainFromMap2", 2, 2, swig_obj)) SWIG_fail;
{
std::map<std::string, std::string> *out = nullptr;
if (PyDict_Check(swig_obj[0])) {
@@ -5200,7 +5087,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap2(PyObject *SWIGUNUS
}
{
try {
- sentencepiece_SentencePieceTrainer_TrainFromMap2((std::map< std::string,std::string > const &)*arg1,arg2);
+ sentencepiece_SentencePieceTrainer__TrainFromMap2((std::map< std::string,std::string > const &)*arg1,arg2);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
@@ -5226,7 +5113,7 @@ fail:
}
-SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap3(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
+SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap3(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
std::map< std::string,std::string > *arg1 = 0 ;
PyObject *swig_obj[1] ;
@@ -5260,7 +5147,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap3(PyObject *SWIGUNUS
}
{
try {
- result = sentencepiece_SentencePieceTrainer_TrainFromMap3((std::map< std::string,std::string > const &)*arg1);
+ result = sentencepiece_SentencePieceTrainer__TrainFromMap3((std::map< std::string,std::string > const &)*arg1);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
@@ -5282,14 +5169,14 @@ fail:
}
-SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap4(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
+SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap4(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *resultobj = 0;
std::map< std::string,std::string > *arg1 = 0 ;
sentencepiece::SentenceIterator *arg2 = (sentencepiece::SentenceIterator *) 0 ;
PyObject *swig_obj[2] ;
sentencepiece::util::bytes result;
- if (!SWIG_Python_UnpackTuple(args, "SentencePieceTrainer_TrainFromMap4", 2, 2, swig_obj)) SWIG_fail;
+ if (!SWIG_Python_UnpackTuple(args, "SentencePieceTrainer__TrainFromMap4", 2, 2, swig_obj)) SWIG_fail;
{
std::map<std::string, std::string> *out = nullptr;
if (PyDict_Check(swig_obj[0])) {
@@ -5326,7 +5213,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap4(PyObject *SWIGUNUS
}
{
try {
- result = sentencepiece_SentencePieceTrainer_TrainFromMap4((std::map< std::string,std::string > const &)*arg1,arg2);
+ result = sentencepiece_SentencePieceTrainer__TrainFromMap4((std::map< std::string,std::string > const &)*arg1,arg2);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
@@ -5365,8 +5252,6 @@ static PyMethodDef SwigMethods[] = {
{ "SWIG_PyInstanceMethod_New", SWIG_PyInstanceMethod_New, METH_O, NULL},
{ "new_SentencePieceProcessor", _wrap_new_SentencePieceProcessor, METH_NOARGS, NULL},
{ "delete_SentencePieceProcessor", _wrap_delete_SentencePieceProcessor, METH_O, NULL},
- { "SentencePieceProcessor_Load", _wrap_SentencePieceProcessor_Load, METH_VARARGS, NULL},
- { "SentencePieceProcessor_LoadOrDie", _wrap_SentencePieceProcessor_LoadOrDie, METH_VARARGS, NULL},
{ "SentencePieceProcessor_LoadFromSerializedProto", _wrap_SentencePieceProcessor_LoadFromSerializedProto, METH_VARARGS, NULL},
{ "SentencePieceProcessor_SetEncodeExtraOptions", _wrap_SentencePieceProcessor_SetEncodeExtraOptions, METH_VARARGS, NULL},
{ "SentencePieceProcessor_SetDecodeExtraOptions", _wrap_SentencePieceProcessor_SetDecodeExtraOptions, METH_VARARGS, NULL},
@@ -5401,15 +5286,14 @@ static PyMethodDef SwigMethods[] = {
{ "SentencePieceProcessor_eos_id", _wrap_SentencePieceProcessor_eos_id, METH_O, NULL},
{ "SentencePieceProcessor_pad_id", _wrap_SentencePieceProcessor_pad_id, METH_O, NULL},
{ "SentencePieceProcessor_serialized_model_proto", _wrap_SentencePieceProcessor_serialized_model_proto, METH_O, NULL},
- { "SentencePieceProcessor___len__", _wrap_SentencePieceProcessor___len__, METH_O, NULL},
- { "SentencePieceProcessor___getitem__", _wrap_SentencePieceProcessor___getitem__, METH_VARARGS, NULL},
+ { "SentencePieceProcessor_LoadFromFile", _wrap_SentencePieceProcessor_LoadFromFile, METH_VARARGS, NULL},
{ "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL},
{ "SentencePieceProcessor_swiginit", SentencePieceProcessor_swiginit, METH_VARARGS, NULL},
- { "SentencePieceTrainer_TrainFromString", _wrap_SentencePieceTrainer_TrainFromString, METH_O, NULL},
- { "SentencePieceTrainer_TrainFromMap", _wrap_SentencePieceTrainer_TrainFromMap, METH_O, NULL},
- { "SentencePieceTrainer_TrainFromMap2", _wrap_SentencePieceTrainer_TrainFromMap2, METH_VARARGS, NULL},
- { "SentencePieceTrainer_TrainFromMap3", _wrap_SentencePieceTrainer_TrainFromMap3, METH_O, NULL},
- { "SentencePieceTrainer_TrainFromMap4", _wrap_SentencePieceTrainer_TrainFromMap4, METH_VARARGS, NULL},
+ { "SentencePieceTrainer__TrainFromString", _wrap_SentencePieceTrainer__TrainFromString, METH_O, NULL},
+ { "SentencePieceTrainer__TrainFromMap", _wrap_SentencePieceTrainer__TrainFromMap, METH_O, NULL},
+ { "SentencePieceTrainer__TrainFromMap2", _wrap_SentencePieceTrainer__TrainFromMap2, METH_VARARGS, NULL},
+ { "SentencePieceTrainer__TrainFromMap3", _wrap_SentencePieceTrainer__TrainFromMap3, METH_O, NULL},
+ { "SentencePieceTrainer__TrainFromMap4", _wrap_SentencePieceTrainer__TrainFromMap4, METH_VARARGS, NULL},
{ "SentencePieceTrainer_swigregister", SentencePieceTrainer_swigregister, METH_O, NULL},
{ NULL, NULL, 0, NULL }
};
diff --git a/python/test/sentencepiece_test.py b/python/test/sentencepiece_test.py
index 59f062e..d16bd35 100755
--- a/python/test/sentencepiece_test.py
+++ b/python/test/sentencepiece_test.py
@@ -21,6 +21,7 @@ import sentencepiece as spm
import unittest
import sys
import os
+import pickle
from collections import defaultdict
@@ -131,6 +132,19 @@ class TestSentencepieceProcessor(unittest.TestCase):
text = text.encode('utf-8')
self.assertEqual(text, self.jasp_.DecodeIds(ids))
+ def test_pickle(self):
+ with open('sp.pickle', 'wb') as f:
+ pickle.dump(self.sp_, f)
+
+ id1 = self.sp_.encode('hello world.', out_type=int)
+
+ with open('sp.pickle', 'rb') as f:
+ sp = pickle.load(f)
+
+ id2 = sp.encode('hello world.', out_type=int)
+
+ self.assertEqual(id1, id2)
+
def test_train(self):
spm.SentencePieceTrainer.Train('--input=' +
os.path.join(data_dir, 'botchan.txt') +