diff options
author | Taku Kudo <taku@google.com> | 2020-05-18 05:18:57 +0300 |
---|---|---|
committer | Taku Kudo <taku@google.com> | 2020-05-18 05:18:57 +0300 |
commit | fbaf2f9b9b59cdd850b797300ef6b9b9ac3fe0af (patch) | |
tree | 946f3bc494e7bab85e979b15b24eb07da9864a6f | |
parent | 43a504eeead2b3226c3675c10b46edfba804f85f (diff) |
supported pickle serialization
-rw-r--r-- | python/.gitignore | 1 | ||||
-rw-r--r-- | python/sentencepiece.i | 501 | ||||
-rw-r--r-- | python/sentencepiece.py | 522 | ||||
-rw-r--r-- | python/sentencepiece_wrap.cxx | 188 | ||||
-rwxr-xr-x | python/test/sentencepiece_test.py | 14 |
5 files changed, 564 insertions, 662 deletions
diff --git a/python/.gitignore b/python/.gitignore index e4b2fd9..78e215b 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -1,2 +1,3 @@ /*.so /build +/*.pickle diff --git a/python/sentencepiece.i b/python/sentencepiece.i index 28e178a..e909506 100644 --- a/python/sentencepiece.i +++ b/python/sentencepiece.i @@ -171,15 +171,12 @@ class PySentenceIterator : public sentencepiece::SentenceIterator { %ignore sentencepiece::SentencePieceProcessor::status; %ignore sentencepiece::SentencePieceProcessor::Encode; -%ignore sentencepiece::SentencePieceProcessor::Encode; %ignore sentencepiece::SentencePieceProcessor::SampleEncode; %ignore sentencepiece::SentencePieceProcessor::NBestEncode; %ignore sentencepiece::SentencePieceProcessor::Decode; %ignore sentencepiece::SentencePieceProcessor::model_proto; -%ignore sentencepiece::SentencePieceProcessor::Load(std::istream *); -%ignore sentencepiece::SentencePieceProcessor::LoadOrDie(std::istream *); -%ignore sentencepiece::SentencePieceProcessor::Load(const ModelProto &); -%ignore sentencepiece::SentencePieceProcessor::Load(std::unique_ptr<ModelProto>); +%ignore sentencepiece::SentencePieceProcessor::Load; +%ignore sentencepiece::SentencePieceProcessor::LoadOrDie; %ignore sentencepiece::pretokenizer::PretokenizerForTrainingInterface; %ignore sentencepiece::SentenceIterator; %ignore sentencepiece::SentencePieceTrainer::Train; @@ -193,50 +190,284 @@ class PySentenceIterator : public sentencepiece::SentenceIterator { %ignore sentencepiece::SentencePieceTrainer::GetPretokenizerForTraining; %extend sentencepiece::SentencePieceProcessor { - - int __len__() { - return $self->GetPieceSize(); + sentencepiece::util::Status LoadFromFile(absl::string_view arg) { + return $self->Load(arg); } - int __getitem__(absl::string_view key) const { - return $self->PieceToId(key); - } +%pythoncode { + def Init(self, + model_file=None, + model_proto=None, + out_type=int, + add_bos=False, + add_eos=False, + reverse=False, + enable_sampling=False, + nbest_size=-1, + alpha=0.1): + """Initialzie sentencepieceProcessor. + + Args: + model_file: The sentencepiece model file path. + model_proto: The sentencepiece model serialized proto. + out_type: output type. int or str. + add_bos: Add <s> to the result (Default = false) + add_eos: Add </s> to the result (Default = false) <s>/</s> is added after + reversing (if enabled). + reverse: Reverses the tokenized sequence (Default = false) + nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout. + nbest_size = {0,1}: No sampling is performed. + nbest_size > 1: samples from the nbest_size results. + nbest_size < 0: assuming that nbest_size is infinite and samples + from the all hypothesis (lattice) using + forward-filtering-and-backward-sampling algorithm. + alpha: Soothing parameter for unigram sampling, and merge probability for + BPE-dropout. + """ + + _sentencepiece_processor_init_native(self) + self._out_type = out_type + self._add_bos = add_bos + self._add_eos = add_eos + self._reverse = reverse + self._enable_sampling = enable_sampling + self._nbest_size = nbest_size + self._alpha = alpha + if model_file or model_proto: + self.Load(model_file=model_file, model_proto=model_proto) + + + def Encode(self, + input, + out_type=None, + add_bos=None, + add_eos=None, + reverse=None, + enable_sampling=None, + nbest_size=None, + alpha=None): + """Encode text input to segmented ids or tokens. + + Args: + input: input string. accepsts list of string. + out_type: output type. int or str. + add_bos: Add <s> to the result (Default = false) + add_eos: Add </s> to the result (Default = false) <s>/</s> is added after + reversing (if enabled). + reverse: Reverses the tokenized sequence (Default = false) + nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout. + nbest_size = {0,1}: No sampling is performed. + nbest_size > 1: samples from the nbest_size results. + nbest_size < 0: assuming that nbest_size is infinite and samples + from the all hypothesis (lattice) using + forward-filtering-and-backward-sampling algorithm. + alpha: Soothing parameter for unigram sampling, and merge probability for + BPE-dropout. + """ + + if out_type is None: + out_type = self._out_type + if add_bos is None: + add_bos = self._add_bos + if add_eos is None: + add_eos = self._add_eos + if reverse is None: + reverse = self._reverse + if enable_sampling is None: + enable_sampling = self._enable_sampling + if nbest_size is None: + nbest_size = self._nbest_size + if alpha is None: + alpha = self._alpha + + if enable_sampling == True and (nbest_size is None or nbest_size == 0 or + nbest_size == 1 or alpha is None or + alpha <= 0.0 or alpha > 1.0): + raise RuntimeError( + 'When enable_sampling is True, We must specify "nbest_size > 1" or "nbest_size = -1", ' + 'and "0.0 < alpha < 1.0". "nbest_size = -1" is enabled only on unigram mode and ' + 'samples from all candidates on the lattice instead of nbest segmentations. ' + ) + + def _encode(text): + if out_type is int: + if enable_sampling: + result = self.SampleEncodeAsIds(text, nbest_size, alpha) + else: + result = self.EncodeAsIds(text) + else: + if enable_sampling: + result = self.SampleEncodeAsPieces(text, nbest_size, alpha) + else: + result = self.EncodeAsPieces(text) + + if reverse: + result.reverse() + if add_bos: + if out_type is int: + result = [self.bos_id()] + result + else: + result = [self.IdToPiece(self.bos_id())] + result + + if add_eos: + if out_type is int: + result = result + [self.eos_id()] + else: + result = result + [self.IdToPiece(self.eos_id())] + + return result + + if type(input) is list: + return [_encode(n) for n in input] + + return _encode(input) + + + def Decode(self, input): + """Decode processed id or token sequences.""" + + if not input: + return self.DecodeIds([]) + elif type(input) is int: + return self.DecodeIds([input]) + elif type(input) is str: + return self.DecodePieces([input]) + + def _decode(input): + if not input: + return self.DecodeIds([]) + if type(input[0]) is int: + return self.DecodeIds(input) + return self.DecodePieces(input) + + if type(input[0]) is list: + return [_decode(n) for n in input] + + return _decode(input) + + + def piece_size(self): + return self.GetPieceSize() + + + def vocab_size(self): + return self.GetPieceSize() + + + def __getstate__(self): + return self.serialized_model_proto() + + + def __setstate__(self, serialized_model_proto): + self.__init__() + self.LoadFromSerializedProto(serialized_model_proto) + + + def __len__(self): + return self.GetPieceSize() + + + def __getitem__(self, piece): + return self.PieceToId(piece) + + + def Load(self, model_file=None, model_proto=None): + """Overwride SentencePieceProcessor.Load to support both model_file and model_proto. + + Args: + model_file: The sentencepiece model file path. + model_proto: The sentencepiece model serialized proto. Either `model_file` + or `model_proto` must be set. + """ + if model_file and model_proto: + raise RuntimeError('model_file and model_proto must be exclusive.') + if model_proto: + return self.LoadFromSerializedProto(model_proto) + return self.LoadFromFile(model_file) +} } %extend sentencepiece::SentencePieceTrainer { - static void TrainFromString(absl::string_view arg) { + static void _TrainFromString(absl::string_view arg) { const auto _status = sentencepiece::SentencePieceTrainer::Train(arg); if (!_status.ok()) throw _status; return; } - static void TrainFromMap(const std::map<std::string, std::string> &args) { + static void _TrainFromMap(const std::map<std::string, std::string> &args) { const auto _status = sentencepiece::SentencePieceTrainer::Train(args); if (!_status.ok()) throw _status; return; } - static void TrainFromMap2(const std::map<std::string, std::string> &args, + static void _TrainFromMap2(const std::map<std::string, std::string> &args, SentenceIterator *iter) { const auto _status = sentencepiece::SentencePieceTrainer::Train(args, iter); if (!_status.ok()) throw _status; return; } - static sentencepiece::util::bytes TrainFromMap3(const std::map<std::string, std::string> &args) { + static sentencepiece::util::bytes _TrainFromMap3(const std::map<std::string, std::string> &args) { sentencepiece::util::bytes model_proto; const auto _status = sentencepiece::SentencePieceTrainer::Train(args, nullptr, &model_proto); if (!_status.ok()) throw _status; return model_proto; } - static sentencepiece::util::bytes TrainFromMap4(const std::map<std::string, std::string> &args, + static sentencepiece::util::bytes _TrainFromMap4(const std::map<std::string, std::string> &args, SentenceIterator *iter) { sentencepiece::util::bytes model_proto; const auto _status = sentencepiece::SentencePieceTrainer::Train(args, iter, &model_proto); if (!_status.ok()) throw _status; return model_proto; } + +%pythoncode { + @staticmethod + def Train(arg=None, **kwargs): + """Train Sentencepiece model. Accept both kwargs and legacy string arg.""" + if arg is not None and type(arg) is str: + return SentencePieceTrainer._TrainFromString(arg) + + def _encode(value): + """Encode value to CSV..""" + if type(value) is list: + if sys.version_info[0] == 3: + f = StringIO() + else: + f = BytesIO() + writer = csv.writer(f, lineterminator='') + writer.writerow([str(v) for v in value]) + return f.getvalue() + else: + return str(value) + + sentence_iterator = None + model_writer = None + new_kwargs = {} + for key, value in kwargs.items(): + if key in ['sentence_iterator', 'sentence_reader']: + sentence_iterator = value + elif key in ['model_writer']: + model_writer = value + else: + new_kwargs[key] = _encode(value) + + if model_writer: + if sentence_iterator: + model_proto = SentencePieceTrainer._TrainFromMap4(new_kwargs, + sentence_iterator) + else: + model_proto = SentencePieceTrainer._TrainFromMap3(new_kwargs) + model_writer.write(model_proto) + else: + if sentence_iterator: + return SentencePieceTrainer._TrainFromMap2(new_kwargs, sentence_iterator) + else: + return SentencePieceTrainer._TrainFromMap(new_kwargs) + + return None +} } %typemap(out) std::vector<int> { @@ -439,227 +670,6 @@ from io import StringIO from io import BytesIO -def _sentencepiece_processor_init(self, - model_file=None, - model_proto=None, - out_type=int, - add_bos=False, - add_eos=False, - reverse=False, - enable_sampling=False, - nbest_size=-1, - alpha=0.1): - """Overwride SentencePieceProcessor.__init__ to add addtional parameters. - - Args: - model_file: The sentencepiece model file path. - model_proto: The sentencepiece model serialized proto. - out_type: output type. int or str. - add_bos: Add <s> to the result (Default = false) - add_eos: Add </s> to the result (Default = false) <s>/</s> is added after - reversing (if enabled). - reverse: Reverses the tokenized sequence (Default = false) - nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout. - nbest_size = {0,1}: No sampling is performed. - nbest_size > 1: samples from the nbest_size results. - nbest_size < 0: assuming that nbest_size is infinite and samples - from the all hypothesis (lattice) using - forward-filtering-and-backward-sampling algorithm. - alpha: Soothing parameter for unigram sampling, and merge probability for - BPE-dropout. - """ - - _sentencepiece_processor_init_native(self) - self._out_type = out_type - self._add_bos = add_bos - self._add_eos = add_eos - self._reverse = reverse - self._enable_sampling = enable_sampling - self._nbest_size = nbest_size - self._alpha = alpha - if model_file or model_proto: - self.Load(model_file=model_file, model_proto=model_proto) - - -def _sentencepiece_processor_load(self, model_file=None, model_proto=None): - """Overwride SentencePieceProcessor.Load to support both model_file and model_proto. - - Args: - model_file: The sentencepiece model file path. - model_proto: The sentencepiece model serialized proto. Either `model_file` - or `model_proto` must be set. - """ - if model_file and model_proto: - raise RuntimeError('model_file and model_proto must be exclusive.') - if model_proto: - return self._LoadFromSerializedProto_native(model_proto) - return self._Load_native(model_file) - - -def _sentencepiece_processor_encode(self, - input, - out_type=None, - add_bos=None, - add_eos=None, - reverse=None, - enable_sampling=None, - nbest_size=None, - alpha=None): - """Encode text input to segmented ids or tokens. - - Args: - input: input string. accepsts list of string. - out_type: output type. int or str. - add_bos: Add <s> to the result (Default = false) - add_eos: Add </s> to the result (Default = false) <s>/</s> is added after - reversing (if enabled). - reverse: Reverses the tokenized sequence (Default = false) - nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout. - nbest_size = {0,1}: No sampling is performed. - nbest_size > 1: samples from the nbest_size results. - nbest_size < 0: assuming that nbest_size is infinite and samples - from the all hypothesis (lattice) using - forward-filtering-and-backward-sampling algorithm. - alpha: Soothing parameter for unigram sampling, and merge probability for - BPE-dropout. - """ - - if out_type is None: - out_type = self._out_type - if add_bos is None: - add_bos = self._add_bos - if add_eos is None: - add_eos = self._add_eos - if reverse is None: - reverse = self._reverse - if enable_sampling is None: - enable_sampling = self._enable_sampling - if nbest_size is None: - nbest_size = self._nbest_size - if alpha is None: - alpha = self._alpha - - if enable_sampling == True and (nbest_size is None or nbest_size == 0 or - nbest_size == 1 or alpha is None or - alpha <= 0.0 or alpha > 1.0): - raise RuntimeError( - 'When enable_sampling is True, We must specify "nbest_size > 1" or "nbest_size = -1", ' - 'and "0.0 < alpha < 1.0". "nbest_size = -1" is enabled only on unigram mode and ' - 'samples from all candidates on the lattice instead of nbest segmentations. ' - ) - - def _encode(text): - if out_type is int: - if enable_sampling: - result = self.SampleEncodeAsIds(text, nbest_size, alpha) - else: - result = self.EncodeAsIds(text) - else: - if enable_sampling: - result = self.SampleEncodeAsPieces(text, nbest_size, alpha) - else: - result = self.EncodeAsPieces(text) - - if reverse: - result.reverse() - if add_bos: - if out_type is int: - result = [self.bos_id()] + result - else: - result = [self.IdToPiece(self.bos_id())] + result - - if add_eos: - if out_type is int: - result = result + [self.eos_id()] - else: - result = result + [self.IdToPiece(self.eos_id())] - - return result - - if type(input) is list: - return [_encode(n) for n in input] - - return _encode(input) - - -def _sentencepiece_processor_decode(self, input): - """Decode processed id or token sequences.""" - - if not input: - return self.DecodeIds([]) - elif type(input) is int: - return self.DecodeIds([input]) - elif type(input) is str: - return self.DecodePieces([input]) - - def _decode(input): - if not input: - return self.DecodeIds([]) - if type(input[0]) is int: - return self.DecodeIds(input) - return self.DecodePieces(input) - - if type(input[0]) is list: - return [_decode(n) for n in input] - - return _decode(input) - -def _sentencepiece_trainer_train(arg=None, **kwargs): - """Train Sentencepiece model. Accept both kwargs and legacy string arg.""" - if arg is not None and type(arg) is str: - return SentencePieceTrainer.TrainFromString(arg) - - def _encode(value): - """Encode value to CSV..""" - if type(value) is list: - if sys.version_info[0] == 3: - f = StringIO() - else: - f = BytesIO() - writer = csv.writer(f, lineterminator='') - writer.writerow([str(v) for v in value]) - return f.getvalue() - else: - return str(value) - - sentence_iterator = None - model_writer = None - new_kwargs = {} - for key, value in kwargs.items(): - if key in ['sentence_iterator', 'sentence_reader']: - sentence_iterator = value - elif key in ['model_writer']: - model_writer = value - else: - new_kwargs[key] = _encode(value) - - if model_writer: - if sentence_iterator: - model_proto = SentencePieceTrainer.TrainFromMap4(new_kwargs, - sentence_iterator) - else: - model_proto = SentencePieceTrainer.TrainFromMap3(new_kwargs) - model_writer.write(model_proto) - else: - if sentence_iterator: - return SentencePieceTrainer.TrainFromMap2(new_kwargs, sentence_iterator) - else: - return SentencePieceTrainer.TrainFromMap(new_kwargs) - - return None - - -def _save_native(classname): - """Stores the origina method as _{method_name}_native.""" - - native_map = {} - for name, method in classname.__dict__.items(): - if name[0] != '_': - native_map[('_%s_native' % name)] = method - for k, v in native_map.items(): - setattr(classname, k, v) - - def _add_snake_case(classname): """Added snake_cased method from CammelCased method.""" @@ -675,8 +685,7 @@ def _add_snake_case(classname): def _batchnize(classname, name): """Enables batch request for the method classname.name.""" - - func = getattr(classname, '_%s_native' % name, None) + func = getattr(classname, name, None) def _batched_func(self, arg): if type(arg) is list: @@ -687,17 +696,11 @@ def _batchnize(classname, name): setattr(classname, name, _batched_func) -_save_native(SentencePieceProcessor) _sentencepiece_processor_init_native = SentencePieceProcessor.__init__ -setattr(SentencePieceProcessor, 'Encode', _sentencepiece_processor_encode) -setattr(SentencePieceProcessor, 'Tokenize', _sentencepiece_processor_encode) -setattr(SentencePieceProcessor, 'Decode', _sentencepiece_processor_decode) -setattr(SentencePieceProcessor, 'Detokenize', _sentencepiece_processor_decode) -setattr(SentencePieceProcessor, 'Load', _sentencepiece_processor_load) -setattr(SentencePieceProcessor, '__init__', _sentencepiece_processor_init) -setattr(SentencePieceProcessor, 'vocab_size', SentencePieceProcessor.GetPieceSize) -setattr(SentencePieceProcessor, 'piece_size', SentencePieceProcessor.GetPieceSize) -setattr(SentencePieceTrainer, 'Train', staticmethod(_sentencepiece_trainer_train)) +setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init) + +SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode +SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode for m in [ 'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused', diff --git a/python/sentencepiece.py b/python/sentencepiece.py index 7fd116e..8f1b038 100644 --- a/python/sentencepiece.py +++ b/python/sentencepiece.py @@ -71,12 +71,6 @@ class SentencePieceProcessor(object): _sentencepiece.SentencePieceProcessor_swiginit(self, _sentencepiece.new_SentencePieceProcessor()) __swig_destroy__ = _sentencepiece.delete_SentencePieceProcessor - def Load(self, filename): - return _sentencepiece.SentencePieceProcessor_Load(self, filename) - - def LoadOrDie(self, filename): - return _sentencepiece.SentencePieceProcessor_LoadOrDie(self, filename) - def LoadFromSerializedProto(self, serialized): return _sentencepiece.SentencePieceProcessor_LoadFromSerializedProto(self, serialized) @@ -179,288 +173,301 @@ class SentencePieceProcessor(object): def serialized_model_proto(self): return _sentencepiece.SentencePieceProcessor_serialized_model_proto(self) - def __len__(self): - return _sentencepiece.SentencePieceProcessor___len__(self) + def LoadFromFile(self, arg): + return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg) + + def Init(self, + model_file=None, + model_proto=None, + out_type=int, + add_bos=False, + add_eos=False, + reverse=False, + enable_sampling=False, + nbest_size=-1, + alpha=0.1): + """Initialzie sentencepieceProcessor. + + Args: + model_file: The sentencepiece model file path. + model_proto: The sentencepiece model serialized proto. + out_type: output type. int or str. + add_bos: Add <s> to the result (Default = false) + add_eos: Add </s> to the result (Default = false) <s>/</s> is added after + reversing (if enabled). + reverse: Reverses the tokenized sequence (Default = false) + nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout. + nbest_size = {0,1}: No sampling is performed. + nbest_size > 1: samples from the nbest_size results. + nbest_size < 0: assuming that nbest_size is infinite and samples + from the all hypothesis (lattice) using + forward-filtering-and-backward-sampling algorithm. + alpha: Soothing parameter for unigram sampling, and merge probability for + BPE-dropout. + """ + + _sentencepiece_processor_init_native(self) + self._out_type = out_type + self._add_bos = add_bos + self._add_eos = add_eos + self._reverse = reverse + self._enable_sampling = enable_sampling + self._nbest_size = nbest_size + self._alpha = alpha + if model_file or model_proto: + self.Load(model_file=model_file, model_proto=model_proto) + + + def Encode(self, + input, + out_type=None, + add_bos=None, + add_eos=None, + reverse=None, + enable_sampling=None, + nbest_size=None, + alpha=None): + """Encode text input to segmented ids or tokens. + + Args: + input: input string. accepsts list of string. + out_type: output type. int or str. + add_bos: Add <s> to the result (Default = false) + add_eos: Add </s> to the result (Default = false) <s>/</s> is added after + reversing (if enabled). + reverse: Reverses the tokenized sequence (Default = false) + nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout. + nbest_size = {0,1}: No sampling is performed. + nbest_size > 1: samples from the nbest_size results. + nbest_size < 0: assuming that nbest_size is infinite and samples + from the all hypothesis (lattice) using + forward-filtering-and-backward-sampling algorithm. + alpha: Soothing parameter for unigram sampling, and merge probability for + BPE-dropout. + """ + + if out_type is None: + out_type = self._out_type + if add_bos is None: + add_bos = self._add_bos + if add_eos is None: + add_eos = self._add_eos + if reverse is None: + reverse = self._reverse + if enable_sampling is None: + enable_sampling = self._enable_sampling + if nbest_size is None: + nbest_size = self._nbest_size + if alpha is None: + alpha = self._alpha + + if enable_sampling == True and (nbest_size is None or nbest_size == 0 or + nbest_size == 1 or alpha is None or + alpha <= 0.0 or alpha > 1.0): + raise RuntimeError( + 'When enable_sampling is True, We must specify "nbest_size > 1" or "nbest_size = -1", ' + 'and "0.0 < alpha < 1.0". "nbest_size = -1" is enabled only on unigram mode and ' + 'samples from all candidates on the lattice instead of nbest segmentations. ' + ) + + def _encode(text): + if out_type is int: + if enable_sampling: + result = self.SampleEncodeAsIds(text, nbest_size, alpha) + else: + result = self.EncodeAsIds(text) + else: + if enable_sampling: + result = self.SampleEncodeAsPieces(text, nbest_size, alpha) + else: + result = self.EncodeAsPieces(text) - def __getitem__(self, key): - return _sentencepiece.SentencePieceProcessor___getitem__(self, key) + if reverse: + result.reverse() + if add_bos: + if out_type is int: + result = [self.bos_id()] + result + else: + result = [self.IdToPiece(self.bos_id())] + result -# Register SentencePieceProcessor in _sentencepiece: -_sentencepiece.SentencePieceProcessor_swigregister(SentencePieceProcessor) + if add_eos: + if out_type is int: + result = result + [self.eos_id()] + else: + result = result + [self.IdToPiece(self.eos_id())] -class SentencePieceTrainer(object): - thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag") + return result - def __init__(self, *args, **kwargs): - raise AttributeError("No constructor defined") - __repr__ = _swig_repr + if type(input) is list: + return [_encode(n) for n in input] - @staticmethod - def TrainFromString(arg): - return _sentencepiece.SentencePieceTrainer_TrainFromString(arg) + return _encode(input) - @staticmethod - def TrainFromMap(args): - return _sentencepiece.SentencePieceTrainer_TrainFromMap(args) - @staticmethod - def TrainFromMap2(args, iter): - return _sentencepiece.SentencePieceTrainer_TrainFromMap2(args, iter) + def Decode(self, input): + """Decode processed id or token sequences.""" - @staticmethod - def TrainFromMap3(args): - return _sentencepiece.SentencePieceTrainer_TrainFromMap3(args) + if not input: + return self.DecodeIds([]) + elif type(input) is int: + return self.DecodeIds([input]) + elif type(input) is str: + return self.DecodePieces([input]) - @staticmethod - def TrainFromMap4(args, iter): - return _sentencepiece.SentencePieceTrainer_TrainFromMap4(args, iter) + def _decode(input): + if not input: + return self.DecodeIds([]) + if type(input[0]) is int: + return self.DecodeIds(input) + return self.DecodePieces(input) -# Register SentencePieceTrainer in _sentencepiece: -_sentencepiece.SentencePieceTrainer_swigregister(SentencePieceTrainer) + if type(input[0]) is list: + return [_decode(n) for n in input] -def SentencePieceTrainer_TrainFromString(arg): - return _sentencepiece.SentencePieceTrainer_TrainFromString(arg) + return _decode(input) -def SentencePieceTrainer_TrainFromMap(args): - return _sentencepiece.SentencePieceTrainer_TrainFromMap(args) -def SentencePieceTrainer_TrainFromMap2(args, iter): - return _sentencepiece.SentencePieceTrainer_TrainFromMap2(args, iter) + def piece_size(self): + return self.GetPieceSize() -def SentencePieceTrainer_TrainFromMap3(args): - return _sentencepiece.SentencePieceTrainer_TrainFromMap3(args) -def SentencePieceTrainer_TrainFromMap4(args, iter): - return _sentencepiece.SentencePieceTrainer_TrainFromMap4(args, iter) + def vocab_size(self): + return self.GetPieceSize() + def __getstate__(self): + return self.serialized_model_proto() -import re -import csv -import sys -from io import StringIO -from io import BytesIO + def __setstate__(self, serialized_model_proto): + self.__init__() + self.LoadFromSerializedProto(serialized_model_proto) -def _sentencepiece_processor_init(self, - model_file=None, - model_proto=None, - out_type=int, - add_bos=False, - add_eos=False, - reverse=False, - enable_sampling=False, - nbest_size=-1, - alpha=0.1): - """Overwride SentencePieceProcessor.__init__ to add addtional parameters. - - Args: - model_file: The sentencepiece model file path. - model_proto: The sentencepiece model serialized proto. - out_type: output type. int or str. - add_bos: Add <s> to the result (Default = false) - add_eos: Add </s> to the result (Default = false) <s>/</s> is added after - reversing (if enabled). - reverse: Reverses the tokenized sequence (Default = false) - nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout. - nbest_size = {0,1}: No sampling is performed. - nbest_size > 1: samples from the nbest_size results. - nbest_size < 0: assuming that nbest_size is infinite and samples - from the all hypothesis (lattice) using - forward-filtering-and-backward-sampling algorithm. - alpha: Soothing parameter for unigram sampling, and merge probability for - BPE-dropout. - """ - - _sentencepiece_processor_init_native(self) - self._out_type = out_type - self._add_bos = add_bos - self._add_eos = add_eos - self._reverse = reverse - self._enable_sampling = enable_sampling - self._nbest_size = nbest_size - self._alpha = alpha - if model_file or model_proto: - self.Load(model_file=model_file, model_proto=model_proto) - - -def _sentencepiece_processor_load(self, model_file=None, model_proto=None): - """Overwride SentencePieceProcessor.Load to support both model_file and model_proto. - - Args: - model_file: The sentencepiece model file path. - model_proto: The sentencepiece model serialized proto. Either `model_file` - or `model_proto` must be set. - """ - if model_file and model_proto: - raise RuntimeError('model_file and model_proto must be exclusive.') - if model_proto: - return self._LoadFromSerializedProto_native(model_proto) - return self._Load_native(model_file) - - -def _sentencepiece_processor_encode(self, - input, - out_type=None, - add_bos=None, - add_eos=None, - reverse=None, - enable_sampling=None, - nbest_size=None, - alpha=None): - """Encode text input to segmented ids or tokens. - - Args: - input: input string. accepsts list of string. - out_type: output type. int or str. - add_bos: Add <s> to the result (Default = false) - add_eos: Add </s> to the result (Default = false) <s>/</s> is added after - reversing (if enabled). - reverse: Reverses the tokenized sequence (Default = false) - nbest_size: sampling parameters for unigram. Invalid for BPE-Dropout. - nbest_size = {0,1}: No sampling is performed. - nbest_size > 1: samples from the nbest_size results. - nbest_size < 0: assuming that nbest_size is infinite and samples - from the all hypothesis (lattice) using - forward-filtering-and-backward-sampling algorithm. - alpha: Soothing parameter for unigram sampling, and merge probability for - BPE-dropout. - """ - - if out_type is None: - out_type = self._out_type - if add_bos is None: - add_bos = self._add_bos - if add_eos is None: - add_eos = self._add_eos - if reverse is None: - reverse = self._reverse - if enable_sampling is None: - enable_sampling = self._enable_sampling - if nbest_size is None: - nbest_size = self._nbest_size - if alpha is None: - alpha = self._alpha - - if enable_sampling == True and (nbest_size is None or nbest_size == 0 or - nbest_size == 1 or alpha is None or - alpha <= 0.0 or alpha > 1.0): - raise RuntimeError( - 'When enable_sampling is True, We must specify "nbest_size > 1" or "nbest_size = -1", ' - 'and "0.0 < alpha < 1.0". "nbest_size = -1" is enabled only on unigram mode and ' - 'samples from all candidates on the lattice instead of nbest segmentations. ' - ) - - def _encode(text): - if out_type is int: - if enable_sampling: - result = self.SampleEncodeAsIds(text, nbest_size, alpha) - else: - result = self.EncodeAsIds(text) - else: - if enable_sampling: - result = self.SampleEncodeAsPieces(text, nbest_size, alpha) - else: - result = self.EncodeAsPieces(text) - if reverse: - result.reverse() - if add_bos: - if out_type is int: - result = [self.bos_id()] + result - else: - result = [self.IdToPiece(self.bos_id())] + result + def __len__(self): + return self.GetPieceSize() + + + def __getitem__(self, piece): + return self.PieceToId(piece) - if add_eos: - if out_type is int: - result = result + [self.eos_id()] - else: - result = result + [self.IdToPiece(self.eos_id())] - return result + def Load(self, model_file=None, model_proto=None): + """Overwride SentencePieceProcessor.Load to support both model_file and model_proto. - if type(input) is list: - return [_encode(n) for n in input] + Args: + model_file: The sentencepiece model file path. + model_proto: The sentencepiece model serialized proto. Either `model_file` + or `model_proto` must be set. + """ + if model_file and model_proto: + raise RuntimeError('model_file and model_proto must be exclusive.') + if model_proto: + return self.LoadFromSerializedProto(model_proto) + return self.LoadFromFile(model_file) + + +# Register SentencePieceProcessor in _sentencepiece: +_sentencepiece.SentencePieceProcessor_swigregister(SentencePieceProcessor) - return _encode(input) +class SentencePieceTrainer(object): + thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag") + def __init__(self, *args, **kwargs): + raise AttributeError("No constructor defined") + __repr__ = _swig_repr -def _sentencepiece_processor_decode(self, input): - """Decode processed id or token sequences.""" + @staticmethod + def _TrainFromString(arg): + return _sentencepiece.SentencePieceTrainer__TrainFromString(arg) - if not input: - return self.DecodeIds([]) - elif type(input) is int: - return self.DecodeIds([input]) - elif type(input) is str: - return self.DecodePieces([input]) + @staticmethod + def _TrainFromMap(args): + return _sentencepiece.SentencePieceTrainer__TrainFromMap(args) - def _decode(input): - if not input: - return self.DecodeIds([]) - if type(input[0]) is int: - return self.DecodeIds(input) - return self.DecodePieces(input) + @staticmethod + def _TrainFromMap2(args, iter): + return _sentencepiece.SentencePieceTrainer__TrainFromMap2(args, iter) - if type(input[0]) is list: - return [_decode(n) for n in input] + @staticmethod + def _TrainFromMap3(args): + return _sentencepiece.SentencePieceTrainer__TrainFromMap3(args) - return _decode(input) + @staticmethod + def _TrainFromMap4(args, iter): + return _sentencepiece.SentencePieceTrainer__TrainFromMap4(args, iter) -def _sentencepiece_trainer_train(arg=None, **kwargs): - """Train Sentencepiece model. Accept both kwargs and legacy string arg.""" - if arg is not None and type(arg) is str: - return SentencePieceTrainer.TrainFromString(arg) + @staticmethod + def Train(arg=None, **kwargs): + """Train Sentencepiece model. Accept both kwargs and legacy string arg.""" + if arg is not None and type(arg) is str: + return SentencePieceTrainer._TrainFromString(arg) + + def _encode(value): + """Encode value to CSV..""" + if type(value) is list: + if sys.version_info[0] == 3: + f = StringIO() + else: + f = BytesIO() + writer = csv.writer(f, lineterminator='') + writer.writerow([str(v) for v in value]) + return f.getvalue() + else: + return str(value) + + sentence_iterator = None + model_writer = None + new_kwargs = {} + for key, value in kwargs.items(): + if key in ['sentence_iterator', 'sentence_reader']: + sentence_iterator = value + elif key in ['model_writer']: + model_writer = value + else: + new_kwargs[key] = _encode(value) - def _encode(value): - """Encode value to CSV..""" - if type(value) is list: - if sys.version_info[0] == 3: - f = StringIO() + if model_writer: + if sentence_iterator: + model_proto = SentencePieceTrainer._TrainFromMap4(new_kwargs, + sentence_iterator) + else: + model_proto = SentencePieceTrainer._TrainFromMap3(new_kwargs) + model_writer.write(model_proto) else: - f = BytesIO() - writer = csv.writer(f, lineterminator='') - writer.writerow([str(v) for v in value]) - return f.getvalue() - else: - return str(value) - - sentence_iterator = None - model_writer = None - new_kwargs = {} - for key, value in kwargs.items(): - if key in ['sentence_iterator', 'sentence_reader']: - sentence_iterator = value - elif key in ['model_writer']: - model_writer = value - else: - new_kwargs[key] = _encode(value) + if sentence_iterator: + return SentencePieceTrainer._TrainFromMap2(new_kwargs, sentence_iterator) + else: + return SentencePieceTrainer._TrainFromMap(new_kwargs) - if model_writer: - if sentence_iterator: - model_proto = SentencePieceTrainer.TrainFromMap4(new_kwargs, - sentence_iterator) - else: - model_proto = SentencePieceTrainer.TrainFromMap3(new_kwargs) - model_writer.write(model_proto) - else: - if sentence_iterator: - return SentencePieceTrainer.TrainFromMap2(new_kwargs, sentence_iterator) - else: - return SentencePieceTrainer.TrainFromMap(new_kwargs) + return None - return None +# Register SentencePieceTrainer in _sentencepiece: +_sentencepiece.SentencePieceTrainer_swigregister(SentencePieceTrainer) -def _save_native(classname): - """Stores the origina method as _{method_name}_native.""" +def SentencePieceTrainer__TrainFromString(arg): + return _sentencepiece.SentencePieceTrainer__TrainFromString(arg) - native_map = {} - for name, method in classname.__dict__.items(): - if name[0] != '_': - native_map[('_%s_native' % name)] = method - for k, v in native_map.items(): - setattr(classname, k, v) +def SentencePieceTrainer__TrainFromMap(args): + return _sentencepiece.SentencePieceTrainer__TrainFromMap(args) + +def SentencePieceTrainer__TrainFromMap2(args, iter): + return _sentencepiece.SentencePieceTrainer__TrainFromMap2(args, iter) + +def SentencePieceTrainer__TrainFromMap3(args): + return _sentencepiece.SentencePieceTrainer__TrainFromMap3(args) + +def SentencePieceTrainer__TrainFromMap4(args, iter): + return _sentencepiece.SentencePieceTrainer__TrainFromMap4(args, iter) + + + +import re +import csv +import sys +from io import StringIO +from io import BytesIO def _add_snake_case(classname): @@ -478,8 +485,7 @@ def _add_snake_case(classname): def _batchnize(classname, name): """Enables batch request for the method classname.name.""" - - func = getattr(classname, '_%s_native' % name, None) + func = getattr(classname, name, None) def _batched_func(self, arg): if type(arg) is list: @@ -490,17 +496,11 @@ def _batchnize(classname, name): setattr(classname, name, _batched_func) -_save_native(SentencePieceProcessor) _sentencepiece_processor_init_native = SentencePieceProcessor.__init__ -setattr(SentencePieceProcessor, 'Encode', _sentencepiece_processor_encode) -setattr(SentencePieceProcessor, 'Tokenize', _sentencepiece_processor_encode) -setattr(SentencePieceProcessor, 'Decode', _sentencepiece_processor_decode) -setattr(SentencePieceProcessor, 'Detokenize', _sentencepiece_processor_decode) -setattr(SentencePieceProcessor, 'Load', _sentencepiece_processor_load) -setattr(SentencePieceProcessor, '__init__', _sentencepiece_processor_init) -setattr(SentencePieceProcessor, 'vocab_size', SentencePieceProcessor.GetPieceSize) -setattr(SentencePieceProcessor, 'piece_size', SentencePieceProcessor.GetPieceSize) -setattr(SentencePieceTrainer, 'Train', staticmethod(_sentencepiece_trainer_train)) +setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init) + +SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode +SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode for m in [ 'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused', diff --git a/python/sentencepiece_wrap.cxx b/python/sentencepiece_wrap.cxx index 12162b6..e47b780 100644 --- a/python/sentencepiece_wrap.cxx +++ b/python/sentencepiece_wrap.cxx @@ -3280,34 +3280,31 @@ SWIGINTERNINLINE PyObject* return PyBool_FromLong(value ? 1 : 0); } -SWIGINTERN int sentencepiece_SentencePieceProcessor___len__(sentencepiece::SentencePieceProcessor *self){ - return self->GetPieceSize(); +SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceProcessor_LoadFromFile(sentencepiece::SentencePieceProcessor *self,absl::string_view arg){ + return self->Load(arg); } -SWIGINTERN int sentencepiece_SentencePieceProcessor___getitem__(sentencepiece::SentencePieceProcessor const *self,absl::string_view key){ - return self->PieceToId(key); - } -SWIGINTERN void sentencepiece_SentencePieceTrainer_TrainFromString(absl::string_view arg){ +SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromString(absl::string_view arg){ const auto _status = sentencepiece::SentencePieceTrainer::Train(arg); if (!_status.ok()) throw _status; return; } -SWIGINTERN void sentencepiece_SentencePieceTrainer_TrainFromMap(std::map< std::string,std::string > const &args){ +SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromMap(std::map< std::string,std::string > const &args){ const auto _status = sentencepiece::SentencePieceTrainer::Train(args); if (!_status.ok()) throw _status; return; } -SWIGINTERN void sentencepiece_SentencePieceTrainer_TrainFromMap2(std::map< std::string,std::string > const &args,sentencepiece::SentenceIterator *iter){ +SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromMap2(std::map< std::string,std::string > const &args,sentencepiece::SentenceIterator *iter){ const auto _status = sentencepiece::SentencePieceTrainer::Train(args, iter); if (!_status.ok()) throw _status; return; } -SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceTrainer_TrainFromMap3(std::map< std::string,std::string > const &args){ +SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceTrainer__TrainFromMap3(std::map< std::string,std::string > const &args){ sentencepiece::util::bytes model_proto; const auto _status = sentencepiece::SentencePieceTrainer::Train(args, nullptr, &model_proto); if (!_status.ok()) throw _status; return model_proto; } -SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceTrainer_TrainFromMap4(std::map< std::string,std::string > const &args,sentencepiece::SentenceIterator *iter){ +SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceTrainer__TrainFromMap4(std::map< std::string,std::string > const &args,sentencepiece::SentenceIterator *iter){ sentencepiece::util::bytes model_proto; const auto _status = sentencepiece::SentencePieceTrainer::Train(args, iter, &model_proto); if (!_status.ok()) throw _status; @@ -3367,90 +3364,6 @@ fail: } -SWIGINTERN PyObject *_wrap_SentencePieceProcessor_Load(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { - PyObject *resultobj = 0; - sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; - absl::string_view arg2 ; - void *argp1 = 0 ; - int res1 = 0 ; - PyObject *swig_obj[2] ; - sentencepiece::util::Status result; - - if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_Load", 2, 2, swig_obj)) SWIG_fail; - res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); - if (!SWIG_IsOK(res1)) { - SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_Load" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'"); - } - arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); - { - const PyInputString ustring(swig_obj[1]); - if (!ustring.IsAvalable()) { - PyErr_SetString(PyExc_TypeError, "not a string"); - SWIG_fail; - } - resultobj = ustring.input_type(); - arg2 = absl::string_view(ustring.data(), ustring.size()); - } - { - try { - result = (arg1)->Load(arg2); - ReleaseResultObject(resultobj); - } - catch (const sentencepiece::util::Status &status) { - SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); - } - } - { - if (!(&result)->ok()) { - SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str()); - } - resultobj = SWIG_From_bool((&result)->ok()); - } - return resultobj; -fail: - return NULL; -} - - -SWIGINTERN PyObject *_wrap_SentencePieceProcessor_LoadOrDie(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { - PyObject *resultobj = 0; - sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; - absl::string_view arg2 ; - void *argp1 = 0 ; - int res1 = 0 ; - PyObject *swig_obj[2] ; - - if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_LoadOrDie", 2, 2, swig_obj)) SWIG_fail; - res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); - if (!SWIG_IsOK(res1)) { - SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_LoadOrDie" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'"); - } - arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); - { - const PyInputString ustring(swig_obj[1]); - if (!ustring.IsAvalable()) { - PyErr_SetString(PyExc_TypeError, "not a string"); - SWIG_fail; - } - resultobj = ustring.input_type(); - arg2 = absl::string_view(ustring.data(), ustring.size()); - } - { - try { - (arg1)->LoadOrDie(arg2); - ReleaseResultObject(resultobj); - } - catch (const sentencepiece::util::Status &status) { - SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); - } - } - resultobj = SWIG_Py_Void(); - return resultobj; -fail: - return NULL; -} - - SWIGINTERN PyObject *_wrap_SentencePieceProcessor_LoadFromSerializedProto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; @@ -4990,50 +4903,19 @@ fail: } -SWIGINTERN PyObject *_wrap_SentencePieceProcessor___len__(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { - PyObject *resultobj = 0; - sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; - void *argp1 = 0 ; - int res1 = 0 ; - PyObject *swig_obj[1] ; - int result; - - if (!args) SWIG_fail; - swig_obj[0] = args; - res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); - if (!SWIG_IsOK(res1)) { - SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor___len__" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'"); - } - arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); - { - try { - result = (int)sentencepiece_SentencePieceProcessor___len__(arg1); - ReleaseResultObject(resultobj); - } - catch (const sentencepiece::util::Status &status) { - SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); - } - } - resultobj = SWIG_From_int(static_cast< int >(result)); - return resultobj; -fail: - return NULL; -} - - -SWIGINTERN PyObject *_wrap_SentencePieceProcessor___getitem__(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { +SWIGINTERN PyObject *_wrap_SentencePieceProcessor_LoadFromFile(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; absl::string_view arg2 ; void *argp1 = 0 ; int res1 = 0 ; PyObject *swig_obj[2] ; - int result; + sentencepiece::util::Status result; - if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor___getitem__", 2, 2, swig_obj)) SWIG_fail; + if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_LoadFromFile", 2, 2, swig_obj)) SWIG_fail; res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); if (!SWIG_IsOK(res1)) { - SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor___getitem__" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_LoadFromFile" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'"); } arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); { @@ -5047,14 +4929,19 @@ SWIGINTERN PyObject *_wrap_SentencePieceProcessor___getitem__(PyObject *SWIGUNUS } { try { - result = (int)sentencepiece_SentencePieceProcessor___getitem__((sentencepiece::SentencePieceProcessor const *)arg1,arg2); + result = sentencepiece_SentencePieceProcessor_LoadFromFile(arg1,arg2); ReleaseResultObject(resultobj); } catch (const sentencepiece::util::Status &status) { SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); } } - resultobj = SWIG_From_int(static_cast< int >(result)); + { + if (!(&result)->ok()) { + SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str()); + } + resultobj = SWIG_From_bool((&result)->ok()); + } return resultobj; fail: return NULL; @@ -5072,7 +4959,7 @@ SWIGINTERN PyObject *SentencePieceProcessor_swiginit(PyObject *SWIGUNUSEDPARM(se return SWIG_Python_InitShadowInstance(args); } -SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromString(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { +SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromString(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; absl::string_view arg1 ; PyObject *swig_obj[1] ; @@ -5090,7 +4977,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromString(PyObject *SWIGUN } { try { - sentencepiece_SentencePieceTrainer_TrainFromString(arg1); + sentencepiece_SentencePieceTrainer__TrainFromString(arg1); ReleaseResultObject(resultobj); } catch (const sentencepiece::util::Status &status) { @@ -5104,7 +4991,7 @@ fail: } -SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { +SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; std::map< std::string,std::string > *arg1 = 0 ; PyObject *swig_obj[1] ; @@ -5137,7 +5024,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap(PyObject *SWIGUNUSE } { try { - sentencepiece_SentencePieceTrainer_TrainFromMap((std::map< std::string,std::string > const &)*arg1); + sentencepiece_SentencePieceTrainer__TrainFromMap((std::map< std::string,std::string > const &)*arg1); ReleaseResultObject(resultobj); } catch (const sentencepiece::util::Status &status) { @@ -5157,13 +5044,13 @@ fail: } -SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap2(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { +SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap2(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; std::map< std::string,std::string > *arg1 = 0 ; sentencepiece::SentenceIterator *arg2 = (sentencepiece::SentenceIterator *) 0 ; PyObject *swig_obj[2] ; - if (!SWIG_Python_UnpackTuple(args, "SentencePieceTrainer_TrainFromMap2", 2, 2, swig_obj)) SWIG_fail; + if (!SWIG_Python_UnpackTuple(args, "SentencePieceTrainer__TrainFromMap2", 2, 2, swig_obj)) SWIG_fail; { std::map<std::string, std::string> *out = nullptr; if (PyDict_Check(swig_obj[0])) { @@ -5200,7 +5087,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap2(PyObject *SWIGUNUS } { try { - sentencepiece_SentencePieceTrainer_TrainFromMap2((std::map< std::string,std::string > const &)*arg1,arg2); + sentencepiece_SentencePieceTrainer__TrainFromMap2((std::map< std::string,std::string > const &)*arg1,arg2); ReleaseResultObject(resultobj); } catch (const sentencepiece::util::Status &status) { @@ -5226,7 +5113,7 @@ fail: } -SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap3(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { +SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap3(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; std::map< std::string,std::string > *arg1 = 0 ; PyObject *swig_obj[1] ; @@ -5260,7 +5147,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap3(PyObject *SWIGUNUS } { try { - result = sentencepiece_SentencePieceTrainer_TrainFromMap3((std::map< std::string,std::string > const &)*arg1); + result = sentencepiece_SentencePieceTrainer__TrainFromMap3((std::map< std::string,std::string > const &)*arg1); ReleaseResultObject(resultobj); } catch (const sentencepiece::util::Status &status) { @@ -5282,14 +5169,14 @@ fail: } -SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap4(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { +SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromMap4(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; std::map< std::string,std::string > *arg1 = 0 ; sentencepiece::SentenceIterator *arg2 = (sentencepiece::SentenceIterator *) 0 ; PyObject *swig_obj[2] ; sentencepiece::util::bytes result; - if (!SWIG_Python_UnpackTuple(args, "SentencePieceTrainer_TrainFromMap4", 2, 2, swig_obj)) SWIG_fail; + if (!SWIG_Python_UnpackTuple(args, "SentencePieceTrainer__TrainFromMap4", 2, 2, swig_obj)) SWIG_fail; { std::map<std::string, std::string> *out = nullptr; if (PyDict_Check(swig_obj[0])) { @@ -5326,7 +5213,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceTrainer_TrainFromMap4(PyObject *SWIGUNUS } { try { - result = sentencepiece_SentencePieceTrainer_TrainFromMap4((std::map< std::string,std::string > const &)*arg1,arg2); + result = sentencepiece_SentencePieceTrainer__TrainFromMap4((std::map< std::string,std::string > const &)*arg1,arg2); ReleaseResultObject(resultobj); } catch (const sentencepiece::util::Status &status) { @@ -5365,8 +5252,6 @@ static PyMethodDef SwigMethods[] = { { "SWIG_PyInstanceMethod_New", SWIG_PyInstanceMethod_New, METH_O, NULL}, { "new_SentencePieceProcessor", _wrap_new_SentencePieceProcessor, METH_NOARGS, NULL}, { "delete_SentencePieceProcessor", _wrap_delete_SentencePieceProcessor, METH_O, NULL}, - { "SentencePieceProcessor_Load", _wrap_SentencePieceProcessor_Load, METH_VARARGS, NULL}, - { "SentencePieceProcessor_LoadOrDie", _wrap_SentencePieceProcessor_LoadOrDie, METH_VARARGS, NULL}, { "SentencePieceProcessor_LoadFromSerializedProto", _wrap_SentencePieceProcessor_LoadFromSerializedProto, METH_VARARGS, NULL}, { "SentencePieceProcessor_SetEncodeExtraOptions", _wrap_SentencePieceProcessor_SetEncodeExtraOptions, METH_VARARGS, NULL}, { "SentencePieceProcessor_SetDecodeExtraOptions", _wrap_SentencePieceProcessor_SetDecodeExtraOptions, METH_VARARGS, NULL}, @@ -5401,15 +5286,14 @@ static PyMethodDef SwigMethods[] = { { "SentencePieceProcessor_eos_id", _wrap_SentencePieceProcessor_eos_id, METH_O, NULL}, { "SentencePieceProcessor_pad_id", _wrap_SentencePieceProcessor_pad_id, METH_O, NULL}, { "SentencePieceProcessor_serialized_model_proto", _wrap_SentencePieceProcessor_serialized_model_proto, METH_O, NULL}, - { "SentencePieceProcessor___len__", _wrap_SentencePieceProcessor___len__, METH_O, NULL}, - { "SentencePieceProcessor___getitem__", _wrap_SentencePieceProcessor___getitem__, METH_VARARGS, NULL}, + { "SentencePieceProcessor_LoadFromFile", _wrap_SentencePieceProcessor_LoadFromFile, METH_VARARGS, NULL}, { "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL}, { "SentencePieceProcessor_swiginit", SentencePieceProcessor_swiginit, METH_VARARGS, NULL}, - { "SentencePieceTrainer_TrainFromString", _wrap_SentencePieceTrainer_TrainFromString, METH_O, NULL}, - { "SentencePieceTrainer_TrainFromMap", _wrap_SentencePieceTrainer_TrainFromMap, METH_O, NULL}, - { "SentencePieceTrainer_TrainFromMap2", _wrap_SentencePieceTrainer_TrainFromMap2, METH_VARARGS, NULL}, - { "SentencePieceTrainer_TrainFromMap3", _wrap_SentencePieceTrainer_TrainFromMap3, METH_O, NULL}, - { "SentencePieceTrainer_TrainFromMap4", _wrap_SentencePieceTrainer_TrainFromMap4, METH_VARARGS, NULL}, + { "SentencePieceTrainer__TrainFromString", _wrap_SentencePieceTrainer__TrainFromString, METH_O, NULL}, + { "SentencePieceTrainer__TrainFromMap", _wrap_SentencePieceTrainer__TrainFromMap, METH_O, NULL}, + { "SentencePieceTrainer__TrainFromMap2", _wrap_SentencePieceTrainer__TrainFromMap2, METH_VARARGS, NULL}, + { "SentencePieceTrainer__TrainFromMap3", _wrap_SentencePieceTrainer__TrainFromMap3, METH_O, NULL}, + { "SentencePieceTrainer__TrainFromMap4", _wrap_SentencePieceTrainer__TrainFromMap4, METH_VARARGS, NULL}, { "SentencePieceTrainer_swigregister", SentencePieceTrainer_swigregister, METH_O, NULL}, { NULL, NULL, 0, NULL } }; diff --git a/python/test/sentencepiece_test.py b/python/test/sentencepiece_test.py index 59f062e..d16bd35 100755 --- a/python/test/sentencepiece_test.py +++ b/python/test/sentencepiece_test.py @@ -21,6 +21,7 @@ import sentencepiece as spm import unittest import sys import os +import pickle from collections import defaultdict @@ -131,6 +132,19 @@ class TestSentencepieceProcessor(unittest.TestCase): text = text.encode('utf-8') self.assertEqual(text, self.jasp_.DecodeIds(ids)) + def test_pickle(self): + with open('sp.pickle', 'wb') as f: + pickle.dump(self.sp_, f) + + id1 = self.sp_.encode('hello world.', out_type=int) + + with open('sp.pickle', 'rb') as f: + sp = pickle.load(f) + + id2 = sp.encode('hello world.', out_type=int) + + self.assertEqual(id1, id2) + def test_train(self): spm.SentencePieceTrainer.Train('--input=' + os.path.join(data_dir, 'botchan.txt') + |